feat(config): allow invalid server filters (#2419)
- Disallow setting a server filter when there is no choice available - Allow setting an invalid server filter when there is at least one choice available - Log at warn level when an invalid server filter is set - Fix #2337
This commit is contained in:
@@ -249,7 +249,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
return fmt.Errorf("checking for IPv6 support: %w", err)
|
return fmt.Errorf("checking for IPv6 support: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = allSettings.Validate(storage, ipv6Supported)
|
err = allSettings.Validate(storage, ipv6Supported, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
|
|||||||
return fmt.Errorf("checking for IPv6 support: %w", err)
|
return fmt.Errorf("checking for IPv6 support: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = allSettings.Validate(storage, ipv6Supported); err != nil {
|
if err = allSettings.Validate(storage, ipv6Supported, logger); err != nil {
|
||||||
return fmt.Errorf("validating settings: %w", err)
|
return fmt.Errorf("validating settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
5
internal/configuration/settings/interfaces.go
Normal file
5
internal/configuration/settings/interfaces.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
type Warner interface {
|
||||||
|
Warn(message string)
|
||||||
|
}
|
||||||
@@ -25,7 +25,7 @@ type Provider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||||
func (p *Provider) validate(vpnType string, storage Storage) (err error) {
|
func (p *Provider) validate(vpnType string, storage Storage, warner Warner) (err error) {
|
||||||
// Validate Name
|
// Validate Name
|
||||||
var validNames []string
|
var validNames []string
|
||||||
if vpnType == vpn.OpenVPN {
|
if vpnType == vpn.OpenVPN {
|
||||||
@@ -48,7 +48,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
|
|||||||
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.ServerSelection.validate(p.Name, storage)
|
err = p.ServerSelection.validate(p.Name, storage, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server selection: %w", err)
|
return fmt.Errorf("server selection: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,14 +91,14 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||||
storage Storage) (err error) {
|
storage Storage, warner Warner) (err error) {
|
||||||
switch ss.VPN {
|
switch ss.VPN {
|
||||||
case vpn.OpenVPN, vpn.Wireguard:
|
case vpn.OpenVPN, vpn.Wireguard:
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
||||||
}
|
}
|
||||||
|
|
||||||
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage)
|
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err // already wrapped error
|
return err // already wrapped error
|
||||||
}
|
}
|
||||||
@@ -111,7 +111,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
|
|||||||
*ss = surfsharkRetroRegion(*ss)
|
*ss = surfsharkRetroRegion(*ss)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateServerFilters(*ss, filterChoices, vpnServiceProvider)
|
err = validateServerFilters(*ss, filterChoices, vpnServiceProvider, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err)
|
return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err)
|
||||||
}
|
}
|
||||||
@@ -142,19 +142,19 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getLocationFilterChoices(vpnServiceProvider string,
|
func getLocationFilterChoices(vpnServiceProvider string,
|
||||||
ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices,
|
ss *ServerSelection, storage Storage, warner Warner) (
|
||||||
err error) {
|
filterChoices models.FilterChoices, err error) {
|
||||||
filterChoices = storage.GetFilterChoices(vpnServiceProvider)
|
filterChoices = storage.GetFilterChoices(vpnServiceProvider)
|
||||||
|
|
||||||
if vpnServiceProvider == providers.Surfshark {
|
if vpnServiceProvider == providers.Surfshark {
|
||||||
// // Retro compatibility
|
// // Retro compatibility
|
||||||
// TODO v4 remove
|
// TODO v4 remove
|
||||||
newAndRetroRegions := append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...) //nolint:gocritic
|
newAndRetroRegions := append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...) //nolint:gocritic
|
||||||
err := validate.AreAllOneOfCaseInsensitive(ss.Regions, newAndRetroRegions)
|
err := atLeastOneIsOneOfCaseInsensitive(ss.Regions, newAndRetroRegions, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Only return error comparing with newer regions, we don't want to confuse the user
|
// Only return error comparing with newer regions, we don't want to confuse the user
|
||||||
// with the retro regions in the error message.
|
// with the retro regions in the error message.
|
||||||
err = validate.AreAllOneOfCaseInsensitive(ss.Regions, filterChoices.Regions)
|
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
|
||||||
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,28 +165,28 @@ func getLocationFilterChoices(vpnServiceProvider string,
|
|||||||
// validateServerFilters validates filters against the choices given as arguments.
|
// validateServerFilters validates filters against the choices given as arguments.
|
||||||
// Set an argument to nil to pass the check for a particular filter.
|
// Set an argument to nil to pass the check for a particular filter.
|
||||||
func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices,
|
func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices,
|
||||||
vpnServiceProvider string) (err error) {
|
vpnServiceProvider string, warner Warner) (err error) {
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Countries, filterChoices.Countries)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
|
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Regions, filterChoices.Regions)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Cities, filterChoices.Cities)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
|
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
|
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||||
}
|
}
|
||||||
@@ -197,12 +197,12 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
|||||||
// which requires a server name for TLS verification.
|
// which requires a server name for TLS verification.
|
||||||
filterChoices.Names = settings.Names
|
filterChoices.Names = settings.Names
|
||||||
}
|
}
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
|
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Categories, filterChoices.Categories)
|
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
|
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
|
||||||
}
|
}
|
||||||
@@ -210,6 +210,42 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func atLeastOneIsOneOfCaseInsensitive(values, choices []string,
|
||||||
|
warner Warner) (err error) {
|
||||||
|
if len(values) > 0 && len(choices) == 0 {
|
||||||
|
return fmt.Errorf("%w", validate.ErrNoChoice)
|
||||||
|
}
|
||||||
|
|
||||||
|
set := make(map[string]struct{}, len(choices))
|
||||||
|
for _, choice := range choices {
|
||||||
|
lowercaseChoice := strings.ToLower(choice)
|
||||||
|
set[lowercaseChoice] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
invalidValues := make([]string, 0, len(values))
|
||||||
|
for _, value := range values {
|
||||||
|
lowercaseValue := strings.ToLower(value)
|
||||||
|
_, ok := set[lowercaseValue]
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
invalidValues = append(invalidValues, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(invalidValues) {
|
||||||
|
case 0:
|
||||||
|
return nil
|
||||||
|
case len(values):
|
||||||
|
return fmt.Errorf("%w: none of %s is one of the choices available %s",
|
||||||
|
validate.ErrValueNotOneOf, strings.Join(values, ", "), strings.Join(choices, ", "))
|
||||||
|
default:
|
||||||
|
warner.Warn(fmt.Sprintf("values %s are not in choices %s",
|
||||||
|
strings.Join(invalidValues, ", "), strings.Join(choices, ", ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvider string) error {
|
func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvider string) error {
|
||||||
switch {
|
switch {
|
||||||
case *settings.FreeOnly &&
|
case *settings.FreeOnly &&
|
||||||
|
|||||||
@@ -36,7 +36,8 @@ type Storage interface {
|
|||||||
// Validate validates all the settings and returns an error
|
// Validate validates all the settings and returns an error
|
||||||
// if one of them is not valid.
|
// if one of them is not valid.
|
||||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||||
func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) {
|
func (s *Settings) Validate(storage Storage, ipv6Supported bool,
|
||||||
|
warner Warner) (err error) {
|
||||||
nameToValidation := map[string]func() error{
|
nameToValidation := map[string]func() error{
|
||||||
"control server": s.ControlServer.validate,
|
"control server": s.ControlServer.validate,
|
||||||
"dns": s.DNS.validate,
|
"dns": s.DNS.validate,
|
||||||
@@ -51,7 +52,7 @@ func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) {
|
|||||||
"version": s.Version.validate,
|
"version": s.Version.validate,
|
||||||
// Pprof validation done in pprof constructor
|
// Pprof validation done in pprof constructor
|
||||||
"VPN": func() error {
|
"VPN": func() error {
|
||||||
return s.VPN.Validate(storage, ipv6Supported)
|
return s.VPN.Validate(storage, ipv6Supported, warner)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,7 +85,7 @@ func (s *Settings) copy() (copied Settings) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Settings) OverrideWith(other Settings,
|
func (s *Settings) OverrideWith(other Settings,
|
||||||
storage Storage, ipv6Supported bool) (err error) {
|
storage Storage, ipv6Supported bool, warner Warner) (err error) {
|
||||||
patchedSettings := s.copy()
|
patchedSettings := s.copy()
|
||||||
patchedSettings.ControlServer.overrideWith(other.ControlServer)
|
patchedSettings.ControlServer.overrideWith(other.ControlServer)
|
||||||
patchedSettings.DNS.overrideWith(other.DNS)
|
patchedSettings.DNS.overrideWith(other.DNS)
|
||||||
@@ -99,7 +100,7 @@ func (s *Settings) OverrideWith(other Settings,
|
|||||||
patchedSettings.Version.overrideWith(other.Version)
|
patchedSettings.Version.overrideWith(other.Version)
|
||||||
patchedSettings.VPN.OverrideWith(other.VPN)
|
patchedSettings.VPN.OverrideWith(other.VPN)
|
||||||
patchedSettings.Pprof.OverrideWith(other.Pprof)
|
patchedSettings.Pprof.OverrideWith(other.Pprof)
|
||||||
err = patchedSettings.Validate(storage, ipv6Supported)
|
err = patchedSettings.Validate(storage, ipv6Supported, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,14 +21,14 @@ type VPN struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||||
func (v *VPN) Validate(storage Storage, ipv6Supported bool) (err error) {
|
func (v *VPN) Validate(storage Storage, ipv6Supported bool, warner Warner) (err error) {
|
||||||
// Validate Type
|
// Validate Type
|
||||||
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
|
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
|
||||||
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.Provider.validate(v.Type, storage)
|
err = v.Provider.validate(v.Type, storage, warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("provider settings: %w", err)
|
return fmt.Errorf("provider settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ func (h *vpnHandler) patchSettings(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
updatedSettings := h.looper.GetSettings() // already copied
|
updatedSettings := h.looper.GetSettings() // already copied
|
||||||
updatedSettings.OverrideWith(overrideSettings)
|
updatedSettings.OverrideWith(overrideSettings)
|
||||||
err = updatedSettings.Validate(h.storage, h.ipv6Supported)
|
err = updatedSettings.Validate(h.storage, h.ipv6Supported, h.warner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user