From 989838757990b358741cc3cb74598d6424b81863 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sun, 12 Jun 2022 14:03:00 +0000 Subject: [PATCH] feat(updater): Configurable min ratio - `UPDATER_MIN_RATIO` variable - `-minratio` flag for CLI operation --- Dockerfile | 1 + internal/cli/update.go | 5 ++++- internal/configuration/settings/errors.go | 1 + .../configuration/settings/helpers/merge.go | 7 +++++++ .../settings/helpers/override.go | 7 +++++++ internal/configuration/settings/updater.go | 19 +++++++++++++++++++ internal/configuration/sources/env/helpers.go | 9 +++++++++ internal/configuration/sources/env/updater.go | 5 +++++ internal/updater/loop/loop.go | 4 ++-- internal/updater/providers.go | 10 +++------- internal/updater/updater.go | 5 +++-- 11 files changed, 61 insertions(+), 12 deletions(-) diff --git a/Dockerfile b/Dockerfile index f9cbd6a7..c43f4938 100644 --- a/Dockerfile +++ b/Dockerfile @@ -153,6 +153,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ HTTP_CONTROL_SERVER_ADDRESS=":8000" \ # Server data updater UPDATER_PERIOD=0 \ + UPDATER_MIN_RATIO=0.8 \ UPDATER_VPN_SERVICE_PROVIDERS= \ # Public IP PUBLICIP_FILE="/tmp/gluetun/ip" \ diff --git a/internal/cli/update.go b/internal/cli/update.go index bd1285ff..c0599519 100644 --- a/internal/cli/update.go +++ b/internal/cli/update.go @@ -41,6 +41,9 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e flagSet.BoolVar(&maintainerMode, "maintainer", false, "Write results to ./internal/storage/servers.json to modify the program (for maintainers)") flagSet.StringVar(&options.DNSAddress, "dns", "8.8.8.8", "DNS resolver address to use") + const defaultMinRatio = 0.8 + flagSet.Float64Var(&options.MinRatio, "minratio", defaultMinRatio, + "Minimum ratio of servers to find for the update to succeed") flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers") flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for") if err := flagSet.Parse(args); err != nil { @@ -83,7 +86,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e unzipper, parallelResolver, ipFetcher, openvpnFileExtractor) updater := updater.New(httpClient, storage, providers, logger) - err = updater.UpdateServers(ctx, options.Providers) + err = updater.UpdateServers(ctx, options.Providers, options.MinRatio) if err != nil { return fmt.Errorf("cannot update server information: %w", err) } diff --git a/internal/configuration/settings/errors.go b/internal/configuration/settings/errors.go index 968a85fa..0092c993 100644 --- a/internal/configuration/settings/errors.go +++ b/internal/configuration/settings/errors.go @@ -10,6 +10,7 @@ var ( ErrFirewallZeroPort = errors.New("cannot have a zero port to block") ErrHostnameNotValid = errors.New("the hostname specified is not valid") ErrISPNotValid = errors.New("the ISP specified is not valid") + ErrMinRatioNotValid = errors.New("minimum ratio is not valid") ErrMissingValue = errors.New("missing value") ErrNameNotValid = errors.New("the server name specified is not valid") ErrOpenVPNClientKeyMissing = errors.New("client key is missing") diff --git a/internal/configuration/settings/helpers/merge.go b/internal/configuration/settings/helpers/merge.go index bb0ca3b8..e1afd670 100644 --- a/internal/configuration/settings/helpers/merge.go +++ b/internal/configuration/settings/helpers/merge.go @@ -34,6 +34,13 @@ func MergeWithInt(existing, other int) (result int) { return other } +func MergeWithFloat64(existing, other float64) (result float64) { + if existing != 0 { + return existing + } + return other +} + func MergeWithStringPtr(existing, other *string) (result *string) { if existing != nil { return existing diff --git a/internal/configuration/settings/helpers/override.go b/internal/configuration/settings/helpers/override.go index 7e1c77e8..bfad029a 100644 --- a/internal/configuration/settings/helpers/override.go +++ b/internal/configuration/settings/helpers/override.go @@ -32,6 +32,13 @@ func OverrideWithInt(existing, other int) (result int) { return other } +func OverrideWithFloat64(existing, other float64) (result float64) { + if other == 0 { + return existing + } + return other +} + func OverrideWithStringPtr(existing, other *string) (result *string) { if other == nil { return existing diff --git a/internal/configuration/settings/updater.go b/internal/configuration/settings/updater.go index 63c144ea..2628f380 100644 --- a/internal/configuration/settings/updater.go +++ b/internal/configuration/settings/updater.go @@ -22,6 +22,10 @@ type Updater struct { // to resolve VPN server hostnames to IP addresses. // It cannot be the empty string in the internal state. DNSAddress string + // MinRatio is the minimum ratio of servers to + // find per provider, compared to the total current + // number of servers. It defaults to 0.8. + MinRatio float64 // Providers is the list of VPN service providers // to update server information for. Providers []string @@ -34,6 +38,11 @@ func (u Updater) Validate() (err error) { ErrUpdaterPeriodTooSmall, *u.Period, minPeriod) } + if u.MinRatio <= 0 || u.MinRatio > 1 { + return fmt.Errorf("%w: %.2f must be between 0+ and 1", + ErrMinRatioNotValid, u.MinRatio) + } + validProviders := providers.All() for _, provider := range u.Providers { valid := false @@ -56,6 +65,7 @@ func (u *Updater) copy() (copied Updater) { return Updater{ Period: helpers.CopyDurationPtr(u.Period), DNSAddress: u.DNSAddress, + MinRatio: u.MinRatio, Providers: helpers.CopyStringSlice(u.Providers), } } @@ -65,6 +75,7 @@ func (u *Updater) copy() (copied Updater) { func (u *Updater) mergeWith(other Updater) { u.Period = helpers.MergeWithDuration(u.Period, other.Period) u.DNSAddress = helpers.MergeWithString(u.DNSAddress, other.DNSAddress) + u.MinRatio = helpers.MergeWithFloat64(u.MinRatio, other.MinRatio) u.Providers = helpers.MergeStringSlices(u.Providers, other.Providers) } @@ -74,12 +85,19 @@ func (u *Updater) mergeWith(other Updater) { func (u *Updater) overrideWith(other Updater) { u.Period = helpers.OverrideWithDuration(u.Period, other.Period) u.DNSAddress = helpers.OverrideWithString(u.DNSAddress, other.DNSAddress) + u.MinRatio = helpers.OverrideWithFloat64(u.MinRatio, other.MinRatio) u.Providers = helpers.OverrideWithStringSlice(u.Providers, other.Providers) } func (u *Updater) SetDefaults(vpnProvider string) { u.Period = helpers.DefaultDuration(u.Period, 0) u.DNSAddress = helpers.DefaultString(u.DNSAddress, "1.1.1.1:53") + + if u.MinRatio == 0 { + const defaultMinRatio = 0.8 + u.MinRatio = defaultMinRatio + } + if len(u.Providers) == 0 && vpnProvider != providers.Custom { u.Providers = []string{vpnProvider} } @@ -97,6 +115,7 @@ func (u Updater) toLinesNode() (node *gotree.Node) { node = gotree.New("Server data updater settings:") node.Appendf("Update period: %s", *u.Period) node.Appendf("DNS address: %s", u.DNSAddress) + node.Appendf("Minimum ratio: %.1f", u.MinRatio) node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", ")) return node diff --git a/internal/configuration/sources/env/helpers.go b/internal/configuration/sources/env/helpers.go index 4e598133..96a516d0 100644 --- a/internal/configuration/sources/env/helpers.go +++ b/internal/configuration/sources/env/helpers.go @@ -38,6 +38,15 @@ func envToInt(envKey string) (n int, err error) { return strconv.Atoi(s) } +func envToFloat64(envKey string) (f float64, err error) { + s := getCleanedEnv(envKey) + if s == "" { + return 0, nil + } + const bits = 64 + return strconv.ParseFloat(s, bits) +} + func envToStringPtr(envKey string) (stringPtr *string) { s := getCleanedEnv(envKey) if s == "" { diff --git a/internal/configuration/sources/env/updater.go b/internal/configuration/sources/env/updater.go index f02b62cd..320f13fb 100644 --- a/internal/configuration/sources/env/updater.go +++ b/internal/configuration/sources/env/updater.go @@ -18,6 +18,11 @@ func readUpdater() (updater settings.Updater, err error) { return updater, err } + updater.MinRatio, err = envToFloat64("UPDATER_MIN_RATIO") + if err != nil { + return updater, fmt.Errorf("environment variable UPDATER_MIN_RATIO: %w", err) + } + updater.Providers = envToCSV("UPDATER_VPN_SERVICE_PROVIDERS") return updater, nil diff --git a/internal/updater/loop/loop.go b/internal/updater/loop/loop.go index 99d8b87e..91d3232e 100644 --- a/internal/updater/loop/loop.go +++ b/internal/updater/loop/loop.go @@ -13,7 +13,7 @@ import ( ) type Updater interface { - UpdateServers(ctx context.Context, providers []string) (err error) + UpdateServers(ctx context.Context, providers []string, minRatio float64) (err error) } type Loop struct { @@ -97,7 +97,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { runWg.Add(1) go func() { defer runWg.Done() - err := l.updater.UpdateServers(updateCtx, settings.Providers) + err := l.updater.UpdateServers(updateCtx, settings.Providers, settings.MinRatio) if err != nil { if updateCtx.Err() == nil { errorCh <- err diff --git a/internal/updater/providers.go b/internal/updater/providers.go index 23028ace..1ca01e97 100644 --- a/internal/updater/providers.go +++ b/internal/updater/providers.go @@ -12,10 +12,11 @@ type Provider interface { FetchServers(ctx context.Context, minServers int) (servers []models.Server, err error) } -func (u *Updater) updateProvider(ctx context.Context, provider Provider) (err error) { +func (u *Updater) updateProvider(ctx context.Context, provider Provider, + minRatio float64) (err error) { providerName := provider.Name() existingServersCount := u.storage.GetServersCount(providerName) - minServers := getMinServers(existingServersCount) + minServers := int(minRatio * float64(existingServersCount)) servers, err := provider.FetchServers(ctx, minServers) if err != nil { return fmt.Errorf("cannot get servers: %w", err) @@ -35,8 +36,3 @@ func (u *Updater) updateProvider(ctx context.Context, provider Provider) (err er } return nil } - -func getMinServers(existingServersCount int) (minServers int) { - const minRatio = 0.8 - return int(minRatio * float64(existingServersCount)) -} diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 8f5f48e1..12efd75b 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -37,7 +37,8 @@ func New(httpClient *http.Client, storage Storage, } } -func (u *Updater) UpdateServers(ctx context.Context, providers []string) (err error) { +func (u *Updater) UpdateServers(ctx context.Context, providers []string, + minRatio float64) (err error) { caser := cases.Title(language.English) for _, providerName := range providers { u.logger.Info("updating " + caser.String(providerName) + " servers...") @@ -45,7 +46,7 @@ func (u *Updater) UpdateServers(ctx context.Context, providers []string) (err er fetcher := u.providers.Get(providerName) // TODO support servers offering only TCP or only UDP // for NordVPN and PureVPN - err := u.updateProvider(ctx, fetcher) + err := u.updateProvider(ctx, fetcher, minRatio) if err == nil { continue }