From 897a9d7f57ed510dc670e813235c13d374372d45 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 17 Aug 2024 12:01:26 +0200 Subject: [PATCH] feat(config): allow invalid server filters (#2419) - Disallow setting a server filter when there is no choice available - Allow setting an invalid server filter when there is at least one choice available - Log at warn level when an invalid server filter is set - Fix #2337 --- cmd/gluetun/main.go | 2 +- internal/cli/openvpnconfig.go | 2 +- internal/configuration/settings/interfaces.go | 5 ++ internal/configuration/settings/provider.go | 4 +- .../configuration/settings/serverselection.go | 66 ++++++++++++++----- internal/configuration/settings/settings.go | 9 +-- internal/configuration/settings/vpn.go | 4 +- internal/server/vpn.go | 2 +- 8 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 internal/configuration/settings/interfaces.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 0bfa2d53..d56065c0 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -249,7 +249,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return fmt.Errorf("checking for IPv6 support: %w", err) } - err = allSettings.Validate(storage, ipv6Supported) + err = allSettings.Validate(storage, ipv6Supported, logger) if err != nil { return err } diff --git a/internal/cli/openvpnconfig.go b/internal/cli/openvpnconfig.go index ccb5c1c8..3ae45441 100644 --- a/internal/cli/openvpnconfig.go +++ b/internal/cli/openvpnconfig.go @@ -59,7 +59,7 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader, return fmt.Errorf("checking for IPv6 support: %w", err) } - if err = allSettings.Validate(storage, ipv6Supported); err != nil { + if err = allSettings.Validate(storage, ipv6Supported, logger); err != nil { return fmt.Errorf("validating settings: %w", err) } diff --git a/internal/configuration/settings/interfaces.go b/internal/configuration/settings/interfaces.go new file mode 100644 index 00000000..a93f1022 --- /dev/null +++ b/internal/configuration/settings/interfaces.go @@ -0,0 +1,5 @@ +package settings + +type Warner interface { + Warn(message string) +} diff --git a/internal/configuration/settings/provider.go b/internal/configuration/settings/provider.go index 2129a186..31a144ff 100644 --- a/internal/configuration/settings/provider.go +++ b/internal/configuration/settings/provider.go @@ -25,7 +25,7 @@ type Provider struct { } // TODO v4 remove pointer for receiver (because of Surfshark). -func (p *Provider) validate(vpnType string, storage Storage) (err error) { +func (p *Provider) validate(vpnType string, storage Storage, warner Warner) (err error) { // Validate Name var validNames []string if vpnType == vpn.OpenVPN { @@ -48,7 +48,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) { return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err) } - err = p.ServerSelection.validate(p.Name, storage) + err = p.ServerSelection.validate(p.Name, storage, warner) if err != nil { return fmt.Errorf("server selection: %w", err) } diff --git a/internal/configuration/settings/serverselection.go b/internal/configuration/settings/serverselection.go index ed254f2f..66c39b13 100644 --- a/internal/configuration/settings/serverselection.go +++ b/internal/configuration/settings/serverselection.go @@ -91,14 +91,14 @@ var ( ) func (ss *ServerSelection) validate(vpnServiceProvider string, - storage Storage) (err error) { + storage Storage, warner Warner) (err error) { switch ss.VPN { case vpn.OpenVPN, vpn.Wireguard: default: return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN) } - filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage) + filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage, warner) if err != nil { return err // already wrapped error } @@ -111,7 +111,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, *ss = surfsharkRetroRegion(*ss) } - err = validateServerFilters(*ss, filterChoices, vpnServiceProvider) + err = validateServerFilters(*ss, filterChoices, vpnServiceProvider, warner) if err != nil { return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err) } @@ -142,19 +142,19 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, } func getLocationFilterChoices(vpnServiceProvider string, - ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices, - err error) { + ss *ServerSelection, storage Storage, warner Warner) ( + filterChoices models.FilterChoices, err error) { filterChoices = storage.GetFilterChoices(vpnServiceProvider) if vpnServiceProvider == providers.Surfshark { // // Retro compatibility // TODO v4 remove newAndRetroRegions := append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...) //nolint:gocritic - err := validate.AreAllOneOfCaseInsensitive(ss.Regions, newAndRetroRegions) + err := atLeastOneIsOneOfCaseInsensitive(ss.Regions, newAndRetroRegions, warner) if err != nil { // Only return error comparing with newer regions, we don't want to confuse the user // with the retro regions in the error message. - err = validate.AreAllOneOfCaseInsensitive(ss.Regions, filterChoices.Regions) + err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner) return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err) } } @@ -165,28 +165,28 @@ func getLocationFilterChoices(vpnServiceProvider string, // validateServerFilters validates filters against the choices given as arguments. // Set an argument to nil to pass the check for a particular filter. func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices, - vpnServiceProvider string) (err error) { - err = validate.AreAllOneOfCaseInsensitive(settings.Countries, filterChoices.Countries) + vpnServiceProvider string, warner Warner) (err error) { + err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner) if err != nil { return fmt.Errorf("%w: %w", ErrCountryNotValid, err) } - err = validate.AreAllOneOfCaseInsensitive(settings.Regions, filterChoices.Regions) + err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner) if err != nil { return fmt.Errorf("%w: %w", ErrRegionNotValid, err) } - err = validate.AreAllOneOfCaseInsensitive(settings.Cities, filterChoices.Cities) + err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner) if err != nil { return fmt.Errorf("%w: %w", ErrCityNotValid, err) } - err = validate.AreAllOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs) + err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner) if err != nil { return fmt.Errorf("%w: %w", ErrISPNotValid, err) } - err = validate.AreAllOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames) + err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner) if err != nil { return fmt.Errorf("%w: %w", ErrHostnameNotValid, err) } @@ -197,12 +197,12 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter // which requires a server name for TLS verification. filterChoices.Names = settings.Names } - err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names) + err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner) if err != nil { return fmt.Errorf("%w: %w", ErrNameNotValid, err) } - err = validate.AreAllOneOfCaseInsensitive(settings.Categories, filterChoices.Categories) + err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner) if err != nil { return fmt.Errorf("%w: %w", ErrCategoryNotValid, err) } @@ -210,6 +210,42 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter return nil } +func atLeastOneIsOneOfCaseInsensitive(values, choices []string, + warner Warner) (err error) { + if len(values) > 0 && len(choices) == 0 { + return fmt.Errorf("%w", validate.ErrNoChoice) + } + + set := make(map[string]struct{}, len(choices)) + for _, choice := range choices { + lowercaseChoice := strings.ToLower(choice) + set[lowercaseChoice] = struct{}{} + } + + invalidValues := make([]string, 0, len(values)) + for _, value := range values { + lowercaseValue := strings.ToLower(value) + _, ok := set[lowercaseValue] + if ok { + continue + } + invalidValues = append(invalidValues, value) + } + + switch len(invalidValues) { + case 0: + return nil + case len(values): + return fmt.Errorf("%w: none of %s is one of the choices available %s", + validate.ErrValueNotOneOf, strings.Join(values, ", "), strings.Join(choices, ", ")) + default: + warner.Warn(fmt.Sprintf("values %s are not in choices %s", + strings.Join(invalidValues, ", "), strings.Join(choices, ", "))) + } + + return nil +} + func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvider string) error { switch { case *settings.FreeOnly && diff --git a/internal/configuration/settings/settings.go b/internal/configuration/settings/settings.go index 736bc2fd..399ad3eb 100644 --- a/internal/configuration/settings/settings.go +++ b/internal/configuration/settings/settings.go @@ -36,7 +36,8 @@ type Storage interface { // Validate validates all the settings and returns an error // if one of them is not valid. // TODO v4 remove pointer for receiver (because of Surfshark). -func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) { +func (s *Settings) Validate(storage Storage, ipv6Supported bool, + warner Warner) (err error) { nameToValidation := map[string]func() error{ "control server": s.ControlServer.validate, "dns": s.DNS.validate, @@ -51,7 +52,7 @@ func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) { "version": s.Version.validate, // Pprof validation done in pprof constructor "VPN": func() error { - return s.VPN.Validate(storage, ipv6Supported) + return s.VPN.Validate(storage, ipv6Supported, warner) }, } @@ -84,7 +85,7 @@ func (s *Settings) copy() (copied Settings) { } func (s *Settings) OverrideWith(other Settings, - storage Storage, ipv6Supported bool) (err error) { + storage Storage, ipv6Supported bool, warner Warner) (err error) { patchedSettings := s.copy() patchedSettings.ControlServer.overrideWith(other.ControlServer) patchedSettings.DNS.overrideWith(other.DNS) @@ -99,7 +100,7 @@ func (s *Settings) OverrideWith(other Settings, patchedSettings.Version.overrideWith(other.Version) patchedSettings.VPN.OverrideWith(other.VPN) patchedSettings.Pprof.OverrideWith(other.Pprof) - err = patchedSettings.Validate(storage, ipv6Supported) + err = patchedSettings.Validate(storage, ipv6Supported, warner) if err != nil { return err } diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index ba2fd93a..9d491bbb 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -21,14 +21,14 @@ type VPN struct { } // TODO v4 remove pointer for receiver (because of Surfshark). -func (v *VPN) Validate(storage Storage, ipv6Supported bool) (err error) { +func (v *VPN) Validate(storage Storage, ipv6Supported bool, warner Warner) (err error) { // Validate Type validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil { return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err) } - err = v.Provider.validate(v.Type, storage) + err = v.Provider.validate(v.Type, storage, warner) if err != nil { return fmt.Errorf("provider settings: %w", err) } diff --git a/internal/server/vpn.go b/internal/server/vpn.go index 7e4aaa0d..cb32c806 100644 --- a/internal/server/vpn.go +++ b/internal/server/vpn.go @@ -116,7 +116,7 @@ func (h *vpnHandler) patchSettings(w http.ResponseWriter, r *http.Request) { updatedSettings := h.looper.GetSettings() // already copied updatedSettings.OverrideWith(overrideSettings) - err = updatedSettings.Validate(h.storage, h.ipv6Supported) + err = updatedSettings.Validate(h.storage, h.ipv6Supported, h.warner) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return