diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 2c0b32c4..c91cd377 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -215,9 +215,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return err } - allServers := storage.GetServers() - - err = allSettings.Validate(allServers) + err = allSettings.Validate(storage) if err != nil { return err } @@ -378,7 +376,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, vpnLogger := logger.New(log.SetComponent("vpn")) vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts, - allServers, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper, + storage, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper, cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient, buildInfo, *allSettings.Version.Enabled) vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler( @@ -386,8 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, go vpnLooper.Run(vpnCtx, vpnDone) updaterLooper := updater.NewLooper(allSettings.Updater, - allServers, storage, vpnLooper.SetServers, httpClient, - logger.New(log.SetComponent("updater"))) + storage, httpClient, logger.New(log.SetComponent("updater"))) updaterHandler, updaterCtx, updaterDone := goshutdown.NewGoRoutineHandler( "updater", goroutine.OptionTimeout(defaultShutdownTimeout)) // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker diff --git a/internal/cli/formatservers.go b/internal/cli/formatservers.go index d7d95e46..f0aad72c 100644 --- a/internal/cli/formatservers.go +++ b/internal/cli/formatservers.go @@ -10,7 +10,6 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/storage" "golang.org/x/text/cases" "golang.org/x/text/language" @@ -80,9 +79,8 @@ func (c *CLI) FormatServers(args []string) error { if err != nil { return fmt.Errorf("cannot create servers storage: %w", err) } - currentServers := storage.GetServers() - formatted := formatServers(currentServers, providerToFormat) + formatted := storage.FormatToMarkdown(providerToFormat) output = filepath.Clean(output) file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) @@ -103,11 +101,3 @@ func (c *CLI) FormatServers(args []string) error { return nil } - -func formatServers(allServers models.AllServers, provider string) (formatted string) { - servers, ok := allServers.ProviderToServers[provider] - if !ok { - panic(fmt.Sprintf("unknown provider in format map: %s", provider)) - } - return servers.ToMarkdown(provider) -} diff --git a/internal/cli/openvpnconfig.go b/internal/cli/openvpnconfig.go index dc29706d..5e2ec48b 100644 --- a/internal/cli/openvpnconfig.go +++ b/internal/cli/openvpnconfig.go @@ -25,18 +25,17 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) e if err != nil { return err } - allServers := storage.GetServers() allSettings, err := source.Read() if err != nil { return err } - if err = allSettings.Validate(allServers); err != nil { + if err = allSettings.Validate(storage); err != nil { return err } - providerConf := provider.New(*allSettings.VPN.Provider.Name, allServers, time.Now) + providerConf := provider.New(*allSettings.VPN.Provider.Name, storage, time.Now) connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection) if err != nil { return err diff --git a/internal/cli/update.go b/internal/cli/update.go index acc04b38..977cbcee 100644 --- a/internal/cli/update.go +++ b/internal/cli/update.go @@ -2,20 +2,17 @@ package cli import ( "context" - "encoding/json" "errors" "flag" "fmt" "net" "net/http" - "os" "strings" "time" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/updater" ) @@ -83,41 +80,19 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e if err != nil { return fmt.Errorf("cannot create servers storage: %w", err) } - currentServers := storage.GetServers() - updater := updater.New(options, httpClient, currentServers, logger) - allServers, err := updater.UpdateServers(ctx) + updater := updater.New(options, httpClient, storage, logger) + err = updater.UpdateServers(ctx) if err != nil { return fmt.Errorf("cannot update server information: %w", err) } - if endUserMode { - if err := storage.FlushToFile(&allServers); err != nil { - return fmt.Errorf("cannot write updated information to file: %w", err) - } - } - if maintainerMode { - if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil { - return fmt.Errorf("cannot write updated information to file: %w", err) + err := storage.FlushToFile(c.repoServersPath) + if err != nil { + return fmt.Errorf("cannot write servers data to embedded JSON file: %w", err) } } return nil } - -func writeToEmbeddedJSON(repoServersPath string, - allServers *models.AllServers) error { - const perms = 0600 - f, err := os.OpenFile(repoServersPath, - os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms) - if err != nil { - return err - } - - defer f.Close() - - encoder := json.NewEncoder(f) - encoder.SetIndent("", " ") - return encoder.Encode(allServers) -} diff --git a/internal/configuration/settings/provider.go b/internal/configuration/settings/provider.go index dfd41076..9d217420 100644 --- a/internal/configuration/settings/provider.go +++ b/internal/configuration/settings/provider.go @@ -6,7 +6,6 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" - "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gotree" ) @@ -23,7 +22,7 @@ type Provider struct { } // TODO v4 remove pointer for receiver (because of Surfshark). -func (p *Provider) validate(vpnType string, allServers models.AllServers) (err error) { +func (p *Provider) validate(vpnType string, storage Storage) (err error) { // Validate Name var validNames []string if vpnType == vpn.OpenVPN { @@ -42,7 +41,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e ErrVPNProviderNameNotValid, *p.Name, helpers.ChoicesOrString(validNames)) } - err = p.ServerSelection.validate(*p.Name, allServers) + err = p.ServerSelection.validate(*p.Name, storage) if err != nil { return fmt.Errorf("server selection: %w", err) } diff --git a/internal/configuration/settings/serverselection.go b/internal/configuration/settings/serverselection.go index 4dceb6d7..adc6f051 100644 --- a/internal/configuration/settings/serverselection.go +++ b/internal/configuration/settings/serverselection.go @@ -68,21 +68,19 @@ var ( ) func (ss *ServerSelection) validate(vpnServiceProvider string, - allServers models.AllServers) (err error) { + storage Storage) (err error) { switch ss.VPN { case vpn.OpenVPN, vpn.Wireguard: default: return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN) } - countryChoices, regionChoices, cityChoices, - ispChoices, nameChoices, hostnameChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, allServers) + filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage) if err != nil { return err // already wrapped error } - err = validateServerFilters(*ss, countryChoices, regionChoices, cityChoices, - ispChoices, nameChoices, hostnameChoices) + err = validateServerFilters(*ss, filterChoices) if err != nil { if errors.Is(err, helpers.ErrNoChoice) { return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err) @@ -135,63 +133,48 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, return nil } -func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection, - allServers models.AllServers) ( - countryChoices, regionChoices, cityChoices, - ispChoices, nameChoices, hostnameChoices []string, +func getLocationFilterChoices(vpnServiceProvider string, + ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices, err error) { - providerServers, ok := allServers.ProviderToServers[vpnServiceProvider] - if !ok && vpnServiceProvider != providers.Custom { - panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider)) - } - servers := providerServers.Servers - countryChoices = validation.ExtractCountries(servers) - regionChoices = validation.ExtractRegions(servers) - cityChoices = validation.ExtractCities(servers) - ispChoices = validation.ExtractISPs(servers) - nameChoices = validation.ExtractServerNames(servers) - hostnameChoices = validation.ExtractHostnames(servers) + filterChoices = storage.GetFilterChoices(vpnServiceProvider) if vpnServiceProvider == providers.Surfshark { // // Retro compatibility // TODO v4 remove - regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...) - if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil { - return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err) + filterChoices.Regions = append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...) + if err := helpers.AreAllOneOf(ss.Regions, filterChoices.Regions); err != nil { + return models.FilterChoices{}, fmt.Errorf("%w: %s", ErrRegionNotValid, err) } *ss = surfsharkRetroRegion(*ss) } - return countryChoices, regionChoices, cityChoices, - ispChoices, nameChoices, hostnameChoices, nil + return filterChoices, nil } // validateServerFilters validates filters against the choices given as arguments. // Set an argument to nil to pass the check for a particular filter. -func validateServerFilters(settings ServerSelection, - countryChoices, regionChoices, cityChoices, ispChoices, - nameChoices, hostnameChoices []string) (err error) { - if err := helpers.AreAllOneOf(settings.Countries, countryChoices); err != nil { +func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices) (err error) { + if err := helpers.AreAllOneOf(settings.Countries, filterChoices.Countries); err != nil { return fmt.Errorf("%w: %s", ErrCountryNotValid, err) } - if err := helpers.AreAllOneOf(settings.Regions, regionChoices); err != nil { + if err := helpers.AreAllOneOf(settings.Regions, filterChoices.Regions); err != nil { return fmt.Errorf("%w: %s", ErrRegionNotValid, err) } - if err := helpers.AreAllOneOf(settings.Cities, cityChoices); err != nil { + if err := helpers.AreAllOneOf(settings.Cities, filterChoices.Cities); err != nil { return fmt.Errorf("%w: %s", ErrCityNotValid, err) } - if err := helpers.AreAllOneOf(settings.ISPs, ispChoices); err != nil { + if err := helpers.AreAllOneOf(settings.ISPs, filterChoices.ISPs); err != nil { return fmt.Errorf("%w: %s", ErrISPNotValid, err) } - if err := helpers.AreAllOneOf(settings.Hostnames, hostnameChoices); err != nil { + if err := helpers.AreAllOneOf(settings.Hostnames, filterChoices.Hostnames); err != nil { return fmt.Errorf("%w: %s", ErrHostnameNotValid, err) } - if err := helpers.AreAllOneOf(settings.Names, nameChoices); err != nil { + if err := helpers.AreAllOneOf(settings.Names, filterChoices.Names); err != nil { return fmt.Errorf("%w: %s", ErrNameNotValid, err) } diff --git a/internal/configuration/settings/settings.go b/internal/configuration/settings/settings.go index 2b4c2939..ba35927f 100644 --- a/internal/configuration/settings/settings.go +++ b/internal/configuration/settings/settings.go @@ -24,10 +24,14 @@ type Settings struct { Pprof pprof.Settings } +type Storage interface { + GetFilterChoices(provider string) models.FilterChoices +} + // Validate validates all the settings and returns an error // if one of them is not valid. // TODO v4 remove pointer for receiver (because of Surfshark). -func (s *Settings) Validate(allServers models.AllServers) (err error) { +func (s *Settings) Validate(storage Storage) (err error) { nameToValidation := map[string]func() error{ "control server": s.ControlServer.validate, "dns": s.DNS.validate, @@ -42,7 +46,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) { "version": s.Version.validate, // Pprof validation done in pprof constructor "VPN": func() error { - return s.VPN.validate(allServers) + return s.VPN.validate(storage) }, } @@ -91,7 +95,7 @@ func (s *Settings) MergeWith(other Settings) { } func (s *Settings) OverrideWith(other Settings, - allServers models.AllServers) (err error) { + storage Storage) (err error) { patchedSettings := s.copy() patchedSettings.ControlServer.overrideWith(other.ControlServer) patchedSettings.DNS.overrideWith(other.DNS) @@ -106,7 +110,7 @@ func (s *Settings) OverrideWith(other Settings, patchedSettings.Version.overrideWith(other.Version) patchedSettings.VPN.overrideWith(other.VPN) patchedSettings.Pprof.MergeWith(other.Pprof) - err = patchedSettings.Validate(allServers) + err = patchedSettings.Validate(storage) if err != nil { return err } diff --git a/internal/configuration/settings/updater.go b/internal/configuration/settings/updater.go index 82db7848..6807695e 100644 --- a/internal/configuration/settings/updater.go +++ b/internal/configuration/settings/updater.go @@ -35,17 +35,18 @@ func (u Updater) Validate() (err error) { ErrUpdaterPeriodTooSmall, *u.Period, minPeriod) } - for i, provider := range u.Providers { + validProviders := providers.All() + for _, provider := range u.Providers { valid := false - for _, validProvider := range providers.All() { + for _, validProvider := range validProviders { if provider == validProvider { valid = true break } } if !valid { - return fmt.Errorf("%w: %s at index %d", - ErrVPNProviderNameNotValid, provider, i) + return fmt.Errorf("%w: %q can only be one of %s", + ErrVPNProviderNameNotValid, provider, helpers.ChoicesOrString(validProviders)) } } diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index fb8cfe2d..5923c5a6 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -6,7 +6,6 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/constants/vpn" - "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gotree" ) @@ -21,7 +20,7 @@ type VPN struct { } // TODO v4 remove pointer for receiver (because of Surfshark). -func (v *VPN) validate(allServers models.AllServers) (err error) { +func (v *VPN) validate(storage Storage) (err error) { // Validate Type validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} if !helpers.IsOneOf(v.Type, validVPNTypes...) { @@ -29,7 +28,7 @@ func (v *VPN) validate(allServers models.AllServers) (err error) { ErrVPNTypeNotValid, v.Type, strings.Join(validVPNTypes, ", ")) } - err = v.Provider.validate(v.Type, allServers) + err = v.Provider.validate(v.Type, storage) if err != nil { return fmt.Errorf("provider settings: %w", err) } diff --git a/internal/models/filters.go b/internal/models/filters.go new file mode 100644 index 00000000..8a7512ac --- /dev/null +++ b/internal/models/filters.go @@ -0,0 +1,10 @@ +package models + +type FilterChoices struct { + Countries []string + Regions []string + Cities []string + ISPs []string + Names []string + Hostnames []string +} diff --git a/internal/models/getservers.go b/internal/models/getservers.go deleted file mode 100644 index 023347ee..00000000 --- a/internal/models/getservers.go +++ /dev/null @@ -1,51 +0,0 @@ -package models - -import ( - "net" -) - -func (a AllServers) GetCopy() (allServersCopy AllServers) { - allServersCopy.Version = a.Version - allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers)) - for provider, servers := range a.ProviderToServers { - allServersCopy.ProviderToServers[provider] = Servers{ - Version: servers.Version, - Timestamp: servers.Timestamp, - Servers: copyServers(servers.Servers), - } - } - return allServersCopy -} - -func copyServers(servers []Server) (serversCopy []Server) { - if servers == nil { - return nil - } - - serversCopy = make([]Server, len(servers)) - for i, server := range servers { - serversCopy[i] = server - serversCopy[i].IPs = copyIPs(server.IPs) - } - - return serversCopy -} - -func copyIPs(toCopy []net.IP) (copied []net.IP) { - if toCopy == nil { - return nil - } - - copied = make([]net.IP, len(toCopy)) - for i := range toCopy { - copied[i] = copyIP(toCopy[i]) - } - - return copied -} - -func copyIP(toCopy net.IP) (copied net.IP) { - copied = make(net.IP, len(toCopy)) - copy(copied, toCopy) - return copied -} diff --git a/internal/models/getservers_test.go b/internal/models/getservers_test.go deleted file mode 100644 index 1381803e..00000000 --- a/internal/models/getservers_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package models - -import ( - "net" - "testing" - - "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_AllServers_GetCopy(t *testing.T) { - allServers := AllServers{ - Version: 1, - ProviderToServers: map[string]Servers{ - providers.Cyberghost: { - Version: 2, - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Expressvpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Fastestvpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.HideMyAss: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Ipvanish: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Ivpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Mullvad: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Nordvpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Perfectprivacy: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Privado: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.PrivateInternetAccess: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Privatevpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Protonvpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Purevpn: { - Version: 1, - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Surfshark: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Torguard: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.VPNUnlimited: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Vyprvpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Wevpn: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - providers.Windscribe: { - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - }, - } - - servers := allServers.GetCopy() - - assert.Equal(t, allServers, servers) -} - -func Test_copyIPs(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - toCopy []net.IP - copied []net.IP - }{ - "nil": {}, - "empty": { - toCopy: []net.IP{}, - copied: []net.IP{}, - }, - "single IP": { - toCopy: []net.IP{{1, 1, 1, 1}}, - copied: []net.IP{{1, 1, 1, 1}}, - }, - "two IPs": { - toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}}, - copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}}, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - // Reserver leading 9 for copy modifications below - for _, ipToCopy := range testCase.toCopy { - require.NotEqual(t, 9, ipToCopy[0]) - } - - copied := copyIPs(testCase.toCopy) - - assert.Equal(t, testCase.copied, copied) - - if len(copied) > 0 { - original := testCase.toCopy[0][0] - testCase.toCopy[0][0] = 9 - assert.NotEqual(t, 9, copied[0][0]) - testCase.toCopy[0][0] = original - - copied[0][0] = 9 - assert.NotEqual(t, 9, testCase.toCopy[0][0]) - } - }) - } -} diff --git a/internal/models/server.go b/internal/models/server.go index 7b1e41f8..de92ba37 100644 --- a/internal/models/server.go +++ b/internal/models/server.go @@ -2,6 +2,7 @@ package models import ( "net" + "reflect" ) type Server struct { @@ -26,3 +27,28 @@ type Server struct { PortForward bool `json:"port_forward,omitempty"` IPs []net.IP `json:"ips,omitempty"` } + +func (s *Server) Equal(other Server) (equal bool) { + if !ipsAreEqual(s.IPs, other.IPs) { + return false + } + + serverCopy := *s + serverCopy.IPs = nil + other.IPs = nil + return reflect.DeepEqual(serverCopy, other) +} + +func ipsAreEqual(a, b []net.IP) (equal bool) { + if len(a) != len(b) { + return false + } + + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + + return true +} diff --git a/internal/models/server_test.go b/internal/models/server_test.go new file mode 100644 index 00000000..98daa516 --- /dev/null +++ b/internal/models/server_test.go @@ -0,0 +1,120 @@ +package models + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Server_Equal(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + a *Server + b Server + equal bool + }{ + "same IPs": { + a: &Server{ + IPs: []net.IP{net.IPv4(1, 2, 3, 4)}, + }, + b: Server{ + IPs: []net.IP{net.IPv4(1, 2, 3, 4)}, + }, + equal: true, + }, + "same IP strings": { + a: &Server{ + IPs: []net.IP{net.IPv4(1, 2, 3, 4)}, + }, + b: Server{ + IPs: []net.IP{{1, 2, 3, 4}}, + }, + equal: true, + }, + "different IPs": { + a: &Server{ + IPs: []net.IP{{1, 2, 3, 4}, {2, 3, 4, 5}}, + }, + b: Server{ + IPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 4}}, + }, + }, + "all fields equal": { + a: &Server{ + VPN: "vpn", + Country: "country", + Region: "region", + City: "city", + ISP: "isp", + Owned: true, + Number: 1, + ServerName: "server_name", + Hostname: "hostname", + TCP: true, + UDP: true, + OvpnX509: "x509", + RetroLoc: "retroloc", + MultiHop: true, + WgPubKey: "wgpubkey", + Free: true, + Stream: true, + PortForward: true, + IPs: []net.IP{net.IPv4(1, 2, 3, 4)}, + }, + b: Server{ + VPN: "vpn", + Country: "country", + Region: "region", + City: "city", + ISP: "isp", + Owned: true, + Number: 1, + ServerName: "server_name", + Hostname: "hostname", + TCP: true, + UDP: true, + OvpnX509: "x509", + RetroLoc: "retroloc", + MultiHop: true, + WgPubKey: "wgpubkey", + Free: true, + Stream: true, + PortForward: true, + IPs: []net.IP{net.IPv4(1, 2, 3, 4)}, + }, + equal: true, + }, + "different field": { + a: &Server{ + VPN: "vpn", + }, + b: Server{ + VPN: "other vpn", + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ipsOfANotNil := testCase.a.IPs != nil + ipsOfBNotNil := testCase.b.IPs != nil + + equal := testCase.a.Equal(testCase.b) + + assert.Equal(t, testCase.equal, equal) + + // Ensure IPs field is not modified + if ipsOfANotNil { + assert.NotNil(t, testCase.a) + } + if ipsOfBNotNil { + assert.NotNil(t, testCase.b) + } + }) + } +} diff --git a/internal/models/servers.go b/internal/models/servers.go index 2297881a..10ae4dfe 100644 --- a/internal/models/servers.go +++ b/internal/models/servers.go @@ -15,18 +15,6 @@ type AllServers struct { ProviderToServers map[string]Servers } -func (a *AllServers) ServersSlice(provider string) []Server { - if provider == providers.Custom { - return nil - } - - servers, ok := a.ProviderToServers[provider] - if !ok { - panic(fmt.Sprintf("provider %s not found in all servers", provider)) - } - return copyServers(servers.Servers) -} - var _ json.Marshaler = (*AllServers)(nil) // MarshalJSON marshals all servers to JSON. diff --git a/internal/provider/common/mocks.go b/internal/provider/common/mocks.go new file mode 100644 index 00000000..3c90d19c --- /dev/null +++ b/internal/provider/common/mocks.go @@ -0,0 +1,66 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/provider/common (interfaces: Storage) + +// Package common is a generated GoMock package. +package common + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + settings "github.com/qdm12/gluetun/internal/configuration/settings" + models "github.com/qdm12/gluetun/internal/models" +) + +// MockStorage is a mock of Storage interface. +type MockStorage struct { + ctrl *gomock.Controller + recorder *MockStorageMockRecorder +} + +// MockStorageMockRecorder is the mock recorder for MockStorage. +type MockStorageMockRecorder struct { + mock *MockStorage +} + +// NewMockStorage creates a new mock instance. +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { + return m.recorder +} + +// FilterServers mocks base method. +func (m *MockStorage) FilterServers(arg0 string, arg1 settings.ServerSelection) ([]models.Server, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FilterServers", arg0, arg1) + ret0, _ := ret[0].([]models.Server) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FilterServers indicates an expected call of FilterServers. +func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1) +} + +// GetServerByName mocks base method. +func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1) + ret0, _ := ret[0].(models.Server) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetServerByName indicates an expected call of GetServerByName. +func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1) +} diff --git a/internal/provider/common/mocks_generate_test.go b/internal/provider/common/mocks_generate_test.go new file mode 100644 index 00000000..bf15d1e7 --- /dev/null +++ b/internal/provider/common/mocks_generate_test.go @@ -0,0 +1,5 @@ +package common + +// Exceptionally, the storage mock is exported since it is used by all +// provider subpackages tests, and it reduces test code duplication a lot. +//go:generate mockgen -destination=mocks.go -package $GOPACKAGE . Storage diff --git a/internal/provider/common/storage.go b/internal/provider/common/storage.go new file mode 100644 index 00000000..37517e0d --- /dev/null +++ b/internal/provider/common/storage.go @@ -0,0 +1,12 @@ +package common + +import ( + "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/models" +) + +type Storage interface { + FilterServers(provider string, selection settings.ServerSelection) ( + servers []models.Server, err error) + GetServerByName(provider, name string) (server models.Server, ok bool) +} diff --git a/internal/provider/cyberghost/connection.go b/internal/provider/cyberghost/connection.go index 175d1237..8b13a80a 100644 --- a/internal/provider/cyberghost/connection.go +++ b/internal/provider/cyberghost/connection.go @@ -2,6 +2,7 @@ package cyberghost import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Cyberghost, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/cyberghost/provider.go b/internal/provider/cyberghost/provider.go index ab1cbf8e..401f6d5c 100644 --- a/internal/provider/cyberghost/provider.go +++ b/internal/provider/cyberghost/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost), } diff --git a/internal/provider/expressvpn/connection.go b/internal/provider/expressvpn/connection.go index 8c02d68e..0e72c048 100644 --- a/internal/provider/expressvpn/connection.go +++ b/internal/provider/expressvpn/connection.go @@ -2,6 +2,7 @@ package expressvpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(0, 1195, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Expressvpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/expressvpn/connection_test.go b/internal/provider/expressvpn/connection_test.go index d0e71413..f04b5f83 100644 --- a/internal/provider/expressvpn/connection_test.go +++ b/internal/provider/expressvpn/connection_test.go @@ -1,41 +1,63 @@ package expressvpn import ( + "errors" "math/rand" "net" "testing" + "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/provider/utils" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/stretchr/testify/assert" ) func Test_Provider_GetConnection(t *testing.T) { t.Parallel() + const provider = providers.Expressvpn + + errTest := errors.New("test error") + boolPtr := func(b bool) *bool { return &b } + testCases := map[string]struct { - servers []models.Server - selection settings.ServerSelection - connection models.Connection - errWrapped error - errMessage string + filteredServers []models.Server + storageErr error + selection settings.ServerSelection + connection models.Connection + errWrapped error + errMessage string + panicMessage string }{ - "no server": { - selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn), - errWrapped: utils.ErrNoServer, - errMessage: "no server", + "error": { + storageErr: errTest, + errWrapped: errTest, + errMessage: "cannot filter servers: test error", }, - "no filter": { - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, + "default OpenVPN TCP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, }, - selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn), + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(true), + }, + }.WithDefaults(provider), + panicMessage: "no default OpenVPN TCP port is defined!", + }, + "default OpenVPN UDP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(false), + }, + }.WithDefaults(provider), connection: models.Connection{ Type: vpn.OpenVPN, IP: net.IPv4(1, 1, 1, 1), @@ -43,38 +65,14 @@ func Test_Provider_GetConnection(t *testing.T) { Protocol: constants.UDP, }, }, - "target IP": { + "default Wireguard port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, selection: settings.ServerSelection{ - TargetIP: net.IPv4(2, 2, 2, 2), - }.WithDefaults(providers.Expressvpn), - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1195, - Protocol: constants.UDP, - }, - }, - "with filter": { - selection: settings.ServerSelection{ - Hostnames: []string{"b"}, - }.WithDefaults(providers.Expressvpn), - servers: []models.Server{ - {Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1195, - Protocol: constants.UDP, - Hostname: "b", - }, + VPN: vpn.Wireguard, + }.WithDefaults(provider), + panicMessage: "no default Wireguard port is defined!", }, } @@ -82,12 +80,23 @@ func Test_Provider_GetConnection(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + storage := common.NewMockStorage(ctrl) + storage.EXPECT().FilterServers(provider, testCase.selection). + Return(testCase.filteredServers, testCase.storageErr) randSource := rand.NewSource(0) - m := New(testCase.servers, randSource) + provider := New(storage, randSource) - connection, err := m.GetConnection(testCase.selection) + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + _, _ = provider.GetConnection(testCase.selection) + }) + return + } + + connection, err := provider.GetConnection(testCase.selection) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { diff --git a/internal/provider/expressvpn/provider.go b/internal/provider/expressvpn/provider.go index 805e3f6f..56a8d50e 100644 --- a/internal/provider/expressvpn/provider.go +++ b/internal/provider/expressvpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn), } diff --git a/internal/provider/fastestvpn/connection.go b/internal/provider/fastestvpn/connection.go index 4a339d98..75e62574 100644 --- a/internal/provider/fastestvpn/connection.go +++ b/internal/provider/fastestvpn/connection.go @@ -2,6 +2,7 @@ package fastestvpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(4443, 4443, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Fastestvpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/fastestvpn/provider.go b/internal/provider/fastestvpn/provider.go index b52c2479..e6c20f74 100644 --- a/internal/provider/fastestvpn/provider.go +++ b/internal/provider/fastestvpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn), } diff --git a/internal/provider/hidemyass/connection.go b/internal/provider/hidemyass/connection.go index c6d7a1a0..f131053f 100644 --- a/internal/provider/hidemyass/connection.go +++ b/internal/provider/hidemyass/connection.go @@ -2,6 +2,7 @@ package hidemyass import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(8080, 553, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.HideMyAss, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/hidemyass/provider.go b/internal/provider/hidemyass/provider.go index e744927d..b217cdb6 100644 --- a/internal/provider/hidemyass/provider.go +++ b/internal/provider/hidemyass/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss), } diff --git a/internal/provider/ipvanish/connection.go b/internal/provider/ipvanish/connection.go index 162f2af8..deab75d8 100644 --- a/internal/provider/ipvanish/connection.go +++ b/internal/provider/ipvanish/connection.go @@ -2,6 +2,7 @@ package ipvanish import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Ipvanish, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/ipvanish/provider.go b/internal/provider/ipvanish/provider.go index beea984a..334af132 100644 --- a/internal/provider/ipvanish/provider.go +++ b/internal/provider/ipvanish/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish), } diff --git a/internal/provider/ivpn/connection.go b/internal/provider/ivpn/connection.go index 1ce0f826..95c19c3c 100644 --- a/internal/provider/ivpn/connection.go +++ b/internal/provider/ivpn/connection.go @@ -2,6 +2,7 @@ package ivpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 1194, 58237) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Ivpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/ivpn/connection_test.go b/internal/provider/ivpn/connection_test.go index 2db0ac22..a89d2fc5 100644 --- a/internal/provider/ivpn/connection_test.go +++ b/internal/provider/ivpn/connection_test.go @@ -1,41 +1,67 @@ package ivpn import ( + "errors" "math/rand" "net" "testing" + "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/provider/utils" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/stretchr/testify/assert" ) func Test_Provider_GetConnection(t *testing.T) { t.Parallel() + const provider = providers.Ivpn + + errTest := errors.New("test error") + boolPtr := func(b bool) *bool { return &b } + testCases := map[string]struct { - servers []models.Server - selection settings.ServerSelection - connection models.Connection - errWrapped error - errMessage string + filteredServers []models.Server + storageErr error + selection settings.ServerSelection + connection models.Connection + errWrapped error + errMessage string }{ - "no server available": { - selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn), - errWrapped: utils.ErrNoServer, - errMessage: "no server", + "error": { + storageErr: errTest, + errWrapped: errTest, + errMessage: "cannot filter servers: test error", }, - "no filter": { - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, + "default OpenVPN TCP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, }, - selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn), + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(true), + }, + }.WithDefaults(provider), + connection: models.Connection{ + Type: vpn.OpenVPN, + IP: net.IPv4(1, 1, 1, 1), + Port: 443, + Protocol: constants.TCP, + }, + }, + "default OpenVPN UDP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(false), + }, + }.WithDefaults(provider), connection: models.Connection{ Type: vpn.OpenVPN, IP: net.IPv4(1, 1, 1, 1), @@ -43,51 +69,36 @@ func Test_Provider_GetConnection(t *testing.T) { Protocol: constants.UDP, }, }, - "target IP": { + "default Wireguard port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, selection: settings.ServerSelection{ - TargetIP: net.IPv4(2, 2, 2, 2), - }.WithDefaults(providers.Ivpn), - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, + VPN: vpn.Wireguard, + }.WithDefaults(provider), connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, + Type: vpn.Wireguard, + IP: net.IPv4(1, 1, 1, 1), + Port: 58237, Protocol: constants.UDP, }, }, - "with filter": { - selection: settings.ServerSelection{ - Hostnames: []string{"b"}, - }.WithDefaults(providers.Ivpn), - servers: []models.Server{ - {Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, - Protocol: constants.UDP, - Hostname: "b", - }, - }, } for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + storage := common.NewMockStorage(ctrl) + storage.EXPECT().FilterServers(provider, testCase.selection). + Return(testCase.filteredServers, testCase.storageErr) randSource := rand.NewSource(0) - m := New(testCase.servers, randSource) + provider := New(storage, randSource) - connection, err := m.GetConnection(testCase.selection) + connection, err := provider.GetConnection(testCase.selection) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { diff --git a/internal/provider/ivpn/provider.go b/internal/provider/ivpn/provider.go index 48f61390..de51ebf4 100644 --- a/internal/provider/ivpn/provider.go +++ b/internal/provider/ivpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn), } diff --git a/internal/provider/mullvad/connection.go b/internal/provider/mullvad/connection.go index ecb1e086..74eb97a3 100644 --- a/internal/provider/mullvad/connection.go +++ b/internal/provider/mullvad/connection.go @@ -2,6 +2,7 @@ package mullvad import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Mullvad, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/mullvad/connection_test.go b/internal/provider/mullvad/connection_test.go index 38b2af16..c70561db 100644 --- a/internal/provider/mullvad/connection_test.go +++ b/internal/provider/mullvad/connection_test.go @@ -1,41 +1,67 @@ package mullvad import ( + "errors" "math/rand" "net" "testing" + "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/provider/utils" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/stretchr/testify/assert" ) func Test_Provider_GetConnection(t *testing.T) { t.Parallel() + const provider = providers.Mullvad + + errTest := errors.New("test error") + boolPtr := func(b bool) *bool { return &b } + testCases := map[string]struct { - servers []models.Server - selection settings.ServerSelection - connection models.Connection - errWrapped error - errMessage string + filteredServers []models.Server + storageErr error + selection settings.ServerSelection + connection models.Connection + errWrapped error + errMessage string }{ - "no server available": { - selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad), - errWrapped: utils.ErrNoServer, - errMessage: "no server", + "error": { + storageErr: errTest, + errWrapped: errTest, + errMessage: "cannot filter servers: test error", }, - "no filter": { - servers: []models.Server{ - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}}, - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}}, + "default OpenVPN TCP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, }, - selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad), + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(true), + }, + }.WithDefaults(provider), + connection: models.Connection{ + Type: vpn.OpenVPN, + IP: net.IPv4(1, 1, 1, 1), + Port: 443, + Protocol: constants.TCP, + }, + }, + "default OpenVPN UDP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(false), + }, + }.WithDefaults(provider), connection: models.Connection{ Type: vpn.OpenVPN, IP: net.IPv4(1, 1, 1, 1), @@ -43,36 +69,17 @@ func Test_Provider_GetConnection(t *testing.T) { Protocol: constants.UDP, }, }, - "target IP": { + "default Wireguard port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, selection: settings.ServerSelection{ - TargetIP: net.IPv4(2, 2, 2, 2), - }.WithDefaults(providers.Mullvad), - servers: []models.Server{ - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}}, - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}}, - }, + VPN: vpn.Wireguard, + }.WithDefaults(provider), connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, - Protocol: constants.UDP, - }, - }, - "with filter": { - selection: settings.ServerSelection{ - Hostnames: []string{"b"}, - }.WithDefaults(providers.Mullvad), - servers: []models.Server{ - {VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, - {VPN: vpn.OpenVPN, UDP: true, Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}}, - {VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, - Hostname: "b", + Type: vpn.Wireguard, + IP: net.IPv4(1, 1, 1, 1), + Port: 51820, Protocol: constants.UDP, }, }, @@ -82,12 +89,16 @@ func Test_Provider_GetConnection(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + storage := common.NewMockStorage(ctrl) + storage.EXPECT().FilterServers(provider, testCase.selection). + Return(testCase.filteredServers, testCase.storageErr) randSource := rand.NewSource(0) - m := New(testCase.servers, randSource) + provider := New(storage, randSource) - connection, err := m.GetConnection(testCase.selection) + connection, err := provider.GetConnection(testCase.selection) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { diff --git a/internal/provider/mullvad/provider.go b/internal/provider/mullvad/provider.go index bd274d7c..e10a6417 100644 --- a/internal/provider/mullvad/provider.go +++ b/internal/provider/mullvad/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Mullvad), } diff --git a/internal/provider/nordvpn/connection.go b/internal/provider/nordvpn/connection.go index 3175ef98..d431634c 100644 --- a/internal/provider/nordvpn/connection.go +++ b/internal/provider/nordvpn/connection.go @@ -2,6 +2,7 @@ package nordvpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Nordvpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/nordvpn/provider.go b/internal/provider/nordvpn/provider.go index e68dda2b..72d27a80 100644 --- a/internal/provider/nordvpn/provider.go +++ b/internal/provider/nordvpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Nordvpn), } diff --git a/internal/provider/perfectprivacy/connection.go b/internal/provider/perfectprivacy/connection.go index 6d6544e0..85265b18 100644 --- a/internal/provider/perfectprivacy/connection.go +++ b/internal/provider/perfectprivacy/connection.go @@ -2,6 +2,7 @@ package perfectprivacy import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Perfectprivacy, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/perfectprivacy/provider.go b/internal/provider/perfectprivacy/provider.go index 85cf6d8a..b604f1b9 100644 --- a/internal/provider/perfectprivacy/provider.go +++ b/internal/provider/perfectprivacy/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Perfectprivacy), } diff --git a/internal/provider/privado/connection.go b/internal/provider/privado/connection.go index bbc41cd5..5529c737 100644 --- a/internal/provider/privado/connection.go +++ b/internal/provider/privado/connection.go @@ -2,6 +2,7 @@ package privado import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Privado, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/privado/provider.go b/internal/provider/privado/provider.go index a736fcd3..2a4ad909 100644 --- a/internal/provider/privado/provider.go +++ b/internal/provider/privado/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Privado), } diff --git a/internal/provider/privateinternetaccess/connection.go b/internal/provider/privateinternetaccess/connection.go index 97c9891e..4326bf3c 100644 --- a/internal/provider/privateinternetaccess/connection.go +++ b/internal/provider/privateinternetaccess/connection.go @@ -2,6 +2,7 @@ package privateinternetaccess import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/privateinternetaccess/presets" "github.com/qdm12/gluetun/internal/provider/utils" @@ -20,5 +21,6 @@ func (p *Provider) GetConnection(selection settings.ServerSelection) ( defaults.OpenVPNUDPPort = 1197 } - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.PrivateInternetAccess, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 3d0d3f9b..eb213845 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -15,25 +15,24 @@ import ( "strings" "time" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/golibs/format" ) var ( - ErrGatewayIPIsNil = errors.New("gateway IP address is nil") - ErrServerNameEmpty = errors.New("server name is empty") + ErrServerNameNotFound = errors.New("server name not found in servers") + ErrGatewayIPIsNil = errors.New("gateway IP address is nil") + ErrServerNameEmpty = errors.New("server name is empty") ) // PortForward obtains a VPN server side port forwarded from PIA. func (p *Provider) PortForward(ctx context.Context, client *http.Client, logger utils.Logger, gateway net.IP, serverName string) ( port uint16, err error) { - var server models.Server - for _, server = range p.servers { - if server.ServerName == serverName { - break - } + server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName) + if !ok { + return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName) } if !server.PortForward { diff --git a/internal/provider/privateinternetaccess/provider.go b/internal/provider/privateinternetaccess/provider.go index 8fc56e8c..9a8c90eb 100644 --- a/internal/provider/privateinternetaccess/provider.go +++ b/internal/provider/privateinternetaccess/provider.go @@ -5,11 +5,11 @@ import ( "time" "github.com/qdm12/gluetun/internal/constants/openvpn" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source timeNow func() time.Time // Port forwarding @@ -17,11 +17,11 @@ type Provider struct { authFilePath string } -func New(servers []models.Server, randSource rand.Source, +func New(storage common.Storage, randSource rand.Source, timeNow func() time.Time) *Provider { const jsonPortForwardPath = "/gluetun/piaportforward.json" return &Provider{ - servers: servers, + storage: storage, timeNow: timeNow, randSource: randSource, portForwardPath: jsonPortForwardPath, diff --git a/internal/provider/privatevpn/connection.go b/internal/provider/privatevpn/connection.go index 8b464de5..608b16eb 100644 --- a/internal/provider/privatevpn/connection.go +++ b/internal/provider/privatevpn/connection.go @@ -2,6 +2,7 @@ package privatevpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Privatevpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/privatevpn/provider.go b/internal/provider/privatevpn/provider.go index c3c4f880..f4476264 100644 --- a/internal/provider/privatevpn/provider.go +++ b/internal/provider/privatevpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn), } diff --git a/internal/provider/protonvpn/connection.go b/internal/provider/protonvpn/connection.go index 8dc65c91..4ace8ab4 100644 --- a/internal/provider/protonvpn/connection.go +++ b/internal/provider/protonvpn/connection.go @@ -2,6 +2,7 @@ package protonvpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Protonvpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/protonvpn/provider.go b/internal/provider/protonvpn/provider.go index 2ba2d3f4..671aa4c2 100644 --- a/internal/provider/protonvpn/provider.go +++ b/internal/provider/protonvpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn), } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index a051dab6..54500570 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -50,52 +50,57 @@ type PortForwarder interface { port uint16, gateway net.IP, serverName string) (err error) } -func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider { - serversSlice := allServers.ServersSlice(provider) +type Storage interface { + FilterServers(provider string, selection settings.ServerSelection) ( + servers []models.Server, err error) + GetServerByName(provider, name string) (server models.Server, ok bool) +} + +func New(provider string, storage Storage, timeNow func() time.Time) Provider { randSource := rand.NewSource(timeNow().UnixNano()) switch provider { case providers.Custom: return custom.New() case providers.Cyberghost: - return cyberghost.New(serversSlice, randSource) + return cyberghost.New(storage, randSource) case providers.Expressvpn: - return expressvpn.New(serversSlice, randSource) + return expressvpn.New(storage, randSource) case providers.Fastestvpn: - return fastestvpn.New(serversSlice, randSource) + return fastestvpn.New(storage, randSource) case providers.HideMyAss: - return hidemyass.New(serversSlice, randSource) + return hidemyass.New(storage, randSource) case providers.Ipvanish: - return ipvanish.New(serversSlice, randSource) + return ipvanish.New(storage, randSource) case providers.Ivpn: - return ivpn.New(serversSlice, randSource) + return ivpn.New(storage, randSource) case providers.Mullvad: - return mullvad.New(serversSlice, randSource) + return mullvad.New(storage, randSource) case providers.Nordvpn: - return nordvpn.New(serversSlice, randSource) + return nordvpn.New(storage, randSource) case providers.Perfectprivacy: - return perfectprivacy.New(serversSlice, randSource) + return perfectprivacy.New(storage, randSource) case providers.Privado: - return privado.New(serversSlice, randSource) + return privado.New(storage, randSource) case providers.PrivateInternetAccess: - return privateinternetaccess.New(serversSlice, randSource, timeNow) + return privateinternetaccess.New(storage, randSource, timeNow) case providers.Privatevpn: - return privatevpn.New(serversSlice, randSource) + return privatevpn.New(storage, randSource) case providers.Protonvpn: - return protonvpn.New(serversSlice, randSource) + return protonvpn.New(storage, randSource) case providers.Purevpn: - return purevpn.New(serversSlice, randSource) + return purevpn.New(storage, randSource) case providers.Surfshark: - return surfshark.New(serversSlice, randSource) + return surfshark.New(storage, randSource) case providers.Torguard: - return torguard.New(serversSlice, randSource) + return torguard.New(storage, randSource) case providers.VPNUnlimited: - return vpnunlimited.New(serversSlice, randSource) + return vpnunlimited.New(storage, randSource) case providers.Vyprvpn: - return vyprvpn.New(serversSlice, randSource) + return vyprvpn.New(storage, randSource) case providers.Wevpn: - return wevpn.New(serversSlice, randSource) + return wevpn.New(storage, randSource) case providers.Windscribe: - return windscribe.New(serversSlice, randSource) + return windscribe.New(storage, randSource) default: panic("provider " + provider + " is unknown") // should never occur } diff --git a/internal/provider/purevpn/connection.go b/internal/provider/purevpn/connection.go index c49a54ba..508dfc1f 100644 --- a/internal/provider/purevpn/connection.go +++ b/internal/provider/purevpn/connection.go @@ -2,6 +2,7 @@ package purevpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(80, 53, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Purevpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/purevpn/provider.go b/internal/provider/purevpn/provider.go index abb9b78e..9d134fba 100644 --- a/internal/provider/purevpn/provider.go +++ b/internal/provider/purevpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn), } diff --git a/internal/provider/surfshark/connection.go b/internal/provider/surfshark/connection.go index 66f69f15..4f343598 100644 --- a/internal/provider/surfshark/connection.go +++ b/internal/provider/surfshark/connection.go @@ -2,6 +2,7 @@ package surfshark import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(1443, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Surfshark, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/surfshark/provider.go b/internal/provider/surfshark/provider.go index 6e9a765f..a65dcc41 100644 --- a/internal/provider/surfshark/provider.go +++ b/internal/provider/surfshark/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark), } diff --git a/internal/provider/torguard/connection.go b/internal/provider/torguard/connection.go index 6ebb4782..50d4d4c8 100644 --- a/internal/provider/torguard/connection.go +++ b/internal/provider/torguard/connection.go @@ -2,6 +2,7 @@ package torguard import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(1912, 1912, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Torguard, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/torguard/provider.go b/internal/provider/torguard/provider.go index f36806dc..42de29b8 100644 --- a/internal/provider/torguard/provider.go +++ b/internal/provider/torguard/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard), } diff --git a/internal/provider/utils/connection.go b/internal/provider/utils/connection.go index 1e3853a5..a3deb4e1 100644 --- a/internal/provider/utils/connection.go +++ b/internal/provider/utils/connection.go @@ -1,7 +1,7 @@ package utils import ( - "errors" + "fmt" "math/rand" "github.com/qdm12/gluetun/internal/configuration/settings" @@ -24,20 +24,20 @@ func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort, } } -var ErrNoServer = errors.New("no server") +type Storage interface { + FilterServers(provider string, selection settings.ServerSelection) ( + servers []models.Server, err error) +} -func GetConnection(servers []models.Server, +func GetConnection(provider string, + storage Storage, selection settings.ServerSelection, defaults ConnectionDefaults, randSource rand.Source) ( connection models.Connection, err error) { - if len(servers) == 0 { - return connection, ErrNoServer - } - - servers = filterServers(servers, selection) - if len(servers) == 0 { - return connection, noServerFoundError(selection) + servers, err := storage.FilterServers(provider, selection) + if err != nil { + return connection, fmt.Errorf("cannot filter servers: %w", err) } protocol := getProtocol(selection) diff --git a/internal/provider/utils/connection_test.go b/internal/provider/utils/connection_test.go index 937fa7fa..d149f41d 100644 --- a/internal/provider/utils/connection_test.go +++ b/internal/provider/utils/connection_test.go @@ -1,23 +1,30 @@ package utils import ( + "errors" "math/rand" "net" "testing" + "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/stretchr/testify/assert" ) func Test_GetConnection(t *testing.T) { t.Parallel() + errTest := errors.New("test error") + testCases := map[string]struct { - servers []models.Server + provider string + filteredServers []models.Server + filterError error serverSelection settings.ServerSelection defaults ConnectionDefaults randSource rand.Source @@ -25,25 +32,13 @@ func Test_GetConnection(t *testing.T) { errWrapped error errMessage string }{ - "no server": { - serverSelection: settings.ServerSelection{}. - WithDefaults(providers.Mullvad), - errWrapped: ErrNoServer, - errMessage: "no server", - }, - "all servers filtered": { - servers: []models.Server{ - {VPN: vpn.Wireguard}, - {VPN: vpn.Wireguard}, - }, - serverSelection: settings.ServerSelection{ - VPN: vpn.OpenVPN, - }.WithDefaults(providers.Mullvad), - errWrapped: ErrNoServerFound, - errMessage: "no server found: for VPN openvpn; protocol udp", + "storage filter error": { + filterError: errTest, + errWrapped: errTest, + errMessage: "cannot filter servers: test error", }, "server without IPs": { - servers: []models.Server{ + filteredServers: []models.Server{ {VPN: vpn.OpenVPN, UDP: true}, {VPN: vpn.OpenVPN, UDP: true}, }, @@ -58,7 +53,7 @@ func Test_GetConnection(t *testing.T) { errMessage: "no connection to pick from", }, "OpenVPN server with hostname": { - servers: []models.Server{ + filteredServers: []models.Server{ { VPN: vpn.OpenVPN, UDP: true, @@ -79,7 +74,7 @@ func Test_GetConnection(t *testing.T) { }, }, "OpenVPN server with x509": { - servers: []models.Server{ + filteredServers: []models.Server{ { VPN: vpn.OpenVPN, UDP: true, @@ -101,7 +96,7 @@ func Test_GetConnection(t *testing.T) { }, }, "server with IPv4 and IPv6": { - servers: []models.Server{ + filteredServers: []models.Server{ { VPN: vpn.OpenVPN, UDP: true, @@ -128,7 +123,7 @@ func Test_GetConnection(t *testing.T) { }, }, "mixed servers": { - servers: []models.Server{ + filteredServers: []models.Server{ { VPN: vpn.OpenVPN, UDP: true, @@ -169,8 +164,14 @@ func Test_GetConnection(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) - connection, err := GetConnection(testCase.servers, + storage := common.NewMockStorage(ctrl) + storage.EXPECT(). + FilterServers(testCase.provider, testCase.serverSelection). + Return(testCase.filteredServers, testCase.filterError) + + connection, err := GetConnection(testCase.provider, storage, testCase.serverSelection, testCase.defaults, testCase.randSource) diff --git a/internal/provider/vpnunlimited/connection.go b/internal/provider/vpnunlimited/connection.go index e81f3136..3a98cee9 100644 --- a/internal/provider/vpnunlimited/connection.go +++ b/internal/provider/vpnunlimited/connection.go @@ -2,6 +2,7 @@ package vpnunlimited import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.VPNUnlimited, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/vpnunlimited/provider.go b/internal/provider/vpnunlimited/provider.go index 624aa6ab..a6dd922a 100644 --- a/internal/provider/vpnunlimited/provider.go +++ b/internal/provider/vpnunlimited/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited), } diff --git a/internal/provider/vyprvpn/connection.go b/internal/provider/vyprvpn/connection.go index 544055b6..fe8ae915 100644 --- a/internal/provider/vyprvpn/connection.go +++ b/internal/provider/vyprvpn/connection.go @@ -2,6 +2,7 @@ package vyprvpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Vyprvpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/vyprvpn/provider.go b/internal/provider/vyprvpn/provider.go index 6b37d68e..82535e3c 100644 --- a/internal/provider/vyprvpn/provider.go +++ b/internal/provider/vyprvpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn), } diff --git a/internal/provider/wevpn/connection.go b/internal/provider/wevpn/connection.go index cb235c57..b1c3c75a 100644 --- a/internal/provider/wevpn/connection.go +++ b/internal/provider/wevpn/connection.go @@ -2,6 +2,7 @@ package wevpn import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(1195, 1194, 0) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Wevpn, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/wevpn/connection_test.go b/internal/provider/wevpn/connection_test.go index b5bff1a3..bf8f562e 100644 --- a/internal/provider/wevpn/connection_test.go +++ b/internal/provider/wevpn/connection_test.go @@ -1,43 +1,68 @@ package wevpn import ( + "errors" "math/rand" "net" "testing" + "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/provider/utils" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/stretchr/testify/assert" ) func Test_Provider_GetConnection(t *testing.T) { t.Parallel() + const provider = providers.Wevpn + + errTest := errors.New("test error") + boolPtr := func(b bool) *bool { return &b } + testCases := map[string]struct { - servers []models.Server - selection settings.ServerSelection - connection models.Connection - errWrapped error - errMessage string + filteredServers []models.Server + storageErr error + selection settings.ServerSelection + connection models.Connection + errWrapped error + errMessage string + panicMessage string }{ - "no server available": { - selection: settings.ServerSelection{ - VPN: vpn.OpenVPN, - }.WithDefaults(providers.Wevpn), - errWrapped: utils.ErrNoServer, - errMessage: "no server", + "error": { + storageErr: errTest, + errWrapped: errTest, + errMessage: "cannot filter servers: test error", }, - "no filter": { - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, + "default OpenVPN TCP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, }, - selection: settings.ServerSelection{}.WithDefaults(providers.Wevpn), + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(true), + }, + }.WithDefaults(provider), + connection: models.Connection{ + Type: vpn.OpenVPN, + IP: net.IPv4(1, 1, 1, 1), + Port: 1195, + Protocol: constants.TCP, + }, + }, + "default OpenVPN UDP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(false), + }, + }.WithDefaults(provider), connection: models.Connection{ Type: vpn.OpenVPN, IP: net.IPv4(1, 1, 1, 1), @@ -45,38 +70,14 @@ func Test_Provider_GetConnection(t *testing.T) { Protocol: constants.UDP, }, }, - "target IP": { + "default Wireguard port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, selection: settings.ServerSelection{ - TargetIP: net.IPv4(2, 2, 2, 2), - }.WithDefaults(providers.Wevpn), - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, - Protocol: constants.UDP, - }, - }, - "with filter": { - selection: settings.ServerSelection{ - Hostnames: []string{"b"}, - }.WithDefaults(providers.Wevpn), - servers: []models.Server{ - {Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, - Hostname: "b", - Protocol: constants.UDP, - }, + VPN: vpn.Wireguard, + }.WithDefaults(provider), + panicMessage: "no default Wireguard port is defined!", }, } @@ -84,12 +85,23 @@ func Test_Provider_GetConnection(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + storage := common.NewMockStorage(ctrl) + storage.EXPECT().FilterServers(provider, testCase.selection). + Return(testCase.filteredServers, testCase.storageErr) randSource := rand.NewSource(0) - m := New(testCase.servers, randSource) + provider := New(storage, randSource) - connection, err := m.GetConnection(testCase.selection) + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + _, _ = provider.GetConnection(testCase.selection) + }) + return + } + + connection, err := provider.GetConnection(testCase.selection) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { diff --git a/internal/provider/wevpn/provider.go b/internal/provider/wevpn/provider.go index 960383ea..68b1ce55 100644 --- a/internal/provider/wevpn/provider.go +++ b/internal/provider/wevpn/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn), } diff --git a/internal/provider/windscribe/connection.go b/internal/provider/windscribe/connection.go index f70c9ea6..039ddada 100644 --- a/internal/provider/windscribe/connection.go +++ b/internal/provider/windscribe/connection.go @@ -2,6 +2,7 @@ package windscribe import ( "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -9,5 +10,6 @@ import ( func (p *Provider) GetConnection(selection settings.ServerSelection) ( connection models.Connection, err error) { defaults := utils.NewConnectionDefaults(443, 1194, 1194) //nolint:gomnd - return utils.GetConnection(p.servers, selection, defaults, p.randSource) + return utils.GetConnection(providers.Windscribe, + p.storage, selection, defaults, p.randSource) } diff --git a/internal/provider/windscribe/connection_test.go b/internal/provider/windscribe/connection_test.go index a6933cd6..79754943 100644 --- a/internal/provider/windscribe/connection_test.go +++ b/internal/provider/windscribe/connection_test.go @@ -1,41 +1,68 @@ package windscribe import ( + "errors" "math/rand" "net" "testing" + "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/provider/utils" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/stretchr/testify/assert" ) func Test_Provider_GetConnection(t *testing.T) { t.Parallel() + const provider = providers.Windscribe + + errTest := errors.New("test error") + boolPtr := func(b bool) *bool { return &b } + testCases := map[string]struct { - servers []models.Server - selection settings.ServerSelection - connection models.Connection - errWrapped error - errMessage string + filteredServers []models.Server + storageErr error + selection settings.ServerSelection + connection models.Connection + errWrapped error + errMessage string + panicMessage string }{ - "no server available": { - selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe), - errWrapped: utils.ErrNoServer, - errMessage: "no server", + "error": { + storageErr: errTest, + errWrapped: errTest, + errMessage: "cannot filter servers: test error", }, - "no filter": { - servers: []models.Server{ - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}}, - {VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}}, + "default OpenVPN TCP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, }, - selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe), + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(true), + }, + }.WithDefaults(provider), + connection: models.Connection{ + Type: vpn.OpenVPN, + IP: net.IPv4(1, 1, 1, 1), + Port: 443, + Protocol: constants.TCP, + }, + }, + "default OpenVPN UDP port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, + selection: settings.ServerSelection{ + OpenVPN: settings.OpenVPNSelection{ + TCP: boolPtr(false), + }, + }.WithDefaults(provider), connection: models.Connection{ Type: vpn.OpenVPN, IP: net.IPv4(1, 1, 1, 1), @@ -43,49 +70,41 @@ func Test_Provider_GetConnection(t *testing.T) { Protocol: constants.UDP, }, }, - "target IP": { + "default Wireguard port": { + filteredServers: []models.Server{ + {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, + }, selection: settings.ServerSelection{ - TargetIP: net.IPv4(2, 2, 2, 2), - }.WithDefaults(providers.Windscribe), - servers: []models.Server{ - {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, + VPN: vpn.Wireguard, + }.WithDefaults(provider), connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), + Type: vpn.Wireguard, + IP: net.IPv4(1, 1, 1, 1), Port: 1194, Protocol: constants.UDP, }, }, - "with filter": { - selection: settings.ServerSelection{ - Hostnames: []string{"b"}, - }.WithDefaults(providers.Windscribe), - servers: []models.Server{ - {Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true}, - {Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true}, - }, - connection: models.Connection{ - Type: vpn.OpenVPN, - IP: net.IPv4(2, 2, 2, 2), - Port: 1194, - Hostname: "b", - Protocol: constants.UDP, - }, - }, } for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + storage := common.NewMockStorage(ctrl) + storage.EXPECT().FilterServers(provider, testCase.selection). + Return(testCase.filteredServers, testCase.storageErr) randSource := rand.NewSource(0) - provider := New(testCase.servers, randSource) + provider := New(storage, randSource) + + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + _, _ = provider.GetConnection(testCase.selection) + }) + return + } connection, err := provider.GetConnection(testCase.selection) @@ -93,6 +112,7 @@ func Test_Provider_GetConnection(t *testing.T) { if testCase.errWrapped != nil { assert.EqualError(t, err, testCase.errMessage) } + assert.Equal(t, testCase.connection, connection) }) } diff --git a/internal/provider/windscribe/provider.go b/internal/provider/windscribe/provider.go index e8adc1de..13b8f6db 100644 --- a/internal/provider/windscribe/provider.go +++ b/internal/provider/windscribe/provider.go @@ -4,19 +4,19 @@ import ( "math/rand" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { - servers []models.Server + storage common.Storage randSource rand.Source utils.NoPortForwarder } -func New(servers []models.Server, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source) *Provider { return &Provider{ - servers: servers, + storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Windscribe), } diff --git a/internal/storage/choices.go b/internal/storage/choices.go new file mode 100644 index 00000000..155f61fb --- /dev/null +++ b/internal/storage/choices.go @@ -0,0 +1,27 @@ +package storage + +import ( + "github.com/qdm12/gluetun/internal/configuration/settings/validation" + "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/qdm12/gluetun/internal/models" +) + +func (s *Storage) GetFilterChoices(provider string) models.FilterChoices { + if provider == providers.Custom { + return models.FilterChoices{} + } + + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() + + serversObject := s.getMergedServersObject(provider) + servers := serversObject.Servers + return models.FilterChoices{ + Countries: validation.ExtractCountries(servers), + Regions: validation.ExtractRegions(servers), + Cities: validation.ExtractCities(servers), + ISPs: validation.ExtractISPs(servers), + Names: validation.ExtractServerNames(servers), + Hostnames: validation.ExtractHostnames(servers), + } +} diff --git a/internal/storage/copy.go b/internal/storage/copy.go new file mode 100644 index 00000000..5f0b8ed6 --- /dev/null +++ b/internal/storage/copy.go @@ -0,0 +1,32 @@ +package storage + +import ( + "net" + + "github.com/qdm12/gluetun/internal/models" +) + +func copyServer(server models.Server) (serverCopy models.Server) { + serverCopy = server + serverCopy.IPs = copyIPs(server.IPs) + return serverCopy +} + +func copyIPs(toCopy []net.IP) (copied []net.IP) { + if toCopy == nil { + return nil + } + + copied = make([]net.IP, len(toCopy)) + for i := range toCopy { + copied[i] = copyIP(toCopy[i]) + } + + return copied +} + +func copyIP(toCopy net.IP) (copied net.IP) { + copied = make(net.IP, len(toCopy)) + copy(copied, toCopy) + return copied +} diff --git a/internal/storage/copy_test.go b/internal/storage/copy_test.go new file mode 100644 index 00000000..bad83d59 --- /dev/null +++ b/internal/storage/copy_test.go @@ -0,0 +1,75 @@ +package storage + +import ( + "net" + "testing" + + "github.com/qdm12/gluetun/internal/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_copyServer(t *testing.T) { + t.Parallel() + + server := models.Server{ + Country: "a", + IPs: []net.IP{{1, 2, 3, 4}}, + } + + serverCopy := copyServer(server) + + assert.Equal(t, server, serverCopy) + // Check for mutation + serverCopy.IPs[0][0] = 9 + assert.NotEqual(t, server, serverCopy) +} + +func Test_copyIPs(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + toCopy []net.IP + copied []net.IP + }{ + "nil": {}, + "empty": { + toCopy: []net.IP{}, + copied: []net.IP{}, + }, + "single IP": { + toCopy: []net.IP{{1, 1, 1, 1}}, + copied: []net.IP{{1, 1, 1, 1}}, + }, + "two IPs": { + toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}}, + copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + // Reserver leading 9 for copy modifications below + for _, ipToCopy := range testCase.toCopy { + require.NotEqual(t, 9, ipToCopy[0]) + } + + copied := copyIPs(testCase.toCopy) + + assert.Equal(t, testCase.copied, copied) + + if len(copied) > 0 { + original := testCase.toCopy[0][0] + testCase.toCopy[0][0] = 9 + assert.NotEqual(t, 9, copied[0][0]) + testCase.toCopy[0][0] = original + + copied[0][0] = 9 + assert.NotEqual(t, 9, testCase.toCopy[0][0]) + } + }) + } +} diff --git a/internal/storage/filter.go b/internal/storage/filter.go new file mode 100644 index 00000000..353c2d87 --- /dev/null +++ b/internal/storage/filter.go @@ -0,0 +1,143 @@ +package storage + +import ( + "strings" + + "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/qdm12/gluetun/internal/constants/vpn" + "github.com/qdm12/gluetun/internal/models" +) + +// FilterServers filter servers for the given provider and according +// to the given selection. The filtered servers are deep copied so they +// are safe for mutation by the caller. +func (s *Storage) FilterServers(provider string, selection settings.ServerSelection) ( + servers []models.Server, err error) { + if provider == providers.Custom { + return nil, nil + } + + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() + + serversObject := s.getMergedServersObject(provider) + allServers := serversObject.Servers + + if len(allServers) == 0 { + return nil, ErrNoServerFound + } + + for _, server := range allServers { + if filterServer(server, selection) { + continue + } + + server = copyServer(server) + servers = append(servers, server) + } + + if len(servers) == 0 { + return nil, noServerFoundError(selection) + } + + return servers, nil +} + +func filterServer(server models.Server, + selection settings.ServerSelection) (filtered bool) { + // Note each condition is split to make sure + // we have full testing coverage. + if server.VPN != selection.VPN { + return true + } + + if filterByProtocol(selection, server.TCP, server.UDP) { + return true + } + + if *selection.MultiHopOnly && !server.MultiHop { + return true + } + + if *selection.FreeOnly && !server.Free { + return true + } + + if *selection.StreamOnly && !server.Stream { + return true + } + + if *selection.OwnedOnly && !server.Owned { + return true + } + + if filterByPossibilities(server.Country, selection.Countries) { + return true + } + + if filterByPossibilities(server.Region, selection.Regions) { + return true + } + + if filterByPossibilities(server.City, selection.Cities) { + return true + } + + if filterByPossibilities(server.ISP, selection.ISPs) { + return true + } + + if filterByPossibilitiesUint16(server.Number, selection.Numbers) { + return true + } + + if filterByPossibilities(server.ServerName, selection.Names) { + return true + } + + if filterByPossibilities(server.Hostname, selection.Hostnames) { + return true + } + + // TODO filter port forward server for PIA + + return false +} + +func filterByPossibilities(value string, possibilities []string) (filtered bool) { + if len(possibilities) == 0 { + return false + } + for _, possibility := range possibilities { + if strings.EqualFold(value, possibility) { + return false + } + } + return true +} + +// TODO merge with filterByPossibilities with generics in Go 1.18. +func filterByPossibilitiesUint16(value uint16, possibilities []uint16) (filtered bool) { + if len(possibilities) == 0 { + return false + } + for _, possibility := range possibilities { + if value == possibility { + return false + } + } + return true +} + +func filterByProtocol(selection settings.ServerSelection, + serverTCP, serverUDP bool) (filtered bool) { + switch selection.VPN { + case vpn.Wireguard: + return !serverUDP + default: // OpenVPN + wantTCP := *selection.OpenVPN.TCP + wantUDP := !wantTCP + return (wantTCP && !serverTCP) || (wantUDP && !serverUDP) + } +} diff --git a/internal/storage/flush.go b/internal/storage/flush.go index ac730e51..bea9b90e 100644 --- a/internal/storage/flush.go +++ b/internal/storage/flush.go @@ -4,21 +4,20 @@ import ( "encoding/json" "os" "path/filepath" - - "github.com/qdm12/gluetun/internal/models" ) -var _ Flusher = (*Storage)(nil) +// FlushToFile flushes the merged servers data to the file +// specified by path, as indented JSON. +func (s *Storage) FlushToFile(path string) error { + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() -type Flusher interface { - FlushToFile(allServers *models.AllServers) error + return s.flushToFile(path) } -func (s *Storage) FlushToFile(allServers *models.AllServers) error { - return flushToFile(s.filepath, allServers) -} - -func flushToFile(path string, servers *models.AllServers) error { +// flushToFile flushes the merged servers data to the file +// specified by path, as indented JSON. It is not thread-safe. +func (s *Storage) flushToFile(path string) error { dirPath := filepath.Dir(path) if err := os.MkdirAll(dirPath, 0644); err != nil { return err @@ -28,11 +27,15 @@ func flushToFile(path string, servers *models.AllServers) error { if err != nil { return err } + encoder := json.NewEncoder(file) encoder.SetIndent("", " ") - if err := encoder.Encode(servers); err != nil { + + err = encoder.Encode(&s.mergedServers) + if err != nil { _ = file.Close() return err } + return file.Close() } diff --git a/internal/provider/utils/formatting.go b/internal/storage/formatting.go similarity index 99% rename from internal/provider/utils/formatting.go rename to internal/storage/formatting.go index 918f6131..1fa0e461 100644 --- a/internal/provider/utils/formatting.go +++ b/internal/storage/formatting.go @@ -1,4 +1,4 @@ -package utils +package storage import ( "errors" diff --git a/internal/storage/servers.go b/internal/storage/servers.go index 0830777c..f0d95966 100644 --- a/internal/storage/servers.go +++ b/internal/storage/servers.go @@ -1,7 +1,118 @@ package storage -import "github.com/qdm12/gluetun/internal/models" +import ( + "fmt" + "time" -func (s *Storage) GetServers() models.AllServers { - return s.mergedServers.GetCopy() + "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/qdm12/gluetun/internal/models" +) + +// SetServers sets the given servers for the given provider +// in the storage in-memory map and saves all the servers +// to file. +// Note the servers given are not copied so the caller must +// NOT MUTATE them after calling this method. +func (s *Storage) SetServers(provider string, servers []models.Server) (err error) { + if provider == providers.Custom { + return + } + + s.mergedMutex.Lock() + defer s.mergedMutex.Unlock() + + serversObject := s.getMergedServersObject(provider) + serversObject.Timestamp = time.Now().Unix() + serversObject.Servers = servers + s.mergedServers.ProviderToServers[provider] = serversObject + + err = s.flushToFile(s.filepath) + if err != nil { + return fmt.Errorf("cannot save servers to file: %w", err) + } + return nil +} + +// GetServerByName returns the server for the given provider +// and server name. It returns `ok` as false if the server is +// not found. The returned server is also deep copied so it is +// safe for mutation and/or thread safe use. +func (s *Storage) GetServerByName(provider, name string) ( + server models.Server, ok bool) { + if provider == providers.Custom { + return server, false + } + + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() + + serversObject := s.getMergedServersObject(provider) + for _, server := range serversObject.Servers { + if server.ServerName == name { + return copyServer(server), true + } + } + + return server, false +} + +// GetServersCount returns the number of servers for the provider given. +func (s *Storage) GetServersCount(provider string) (count int) { + if provider == providers.Custom { + return 0 + } + + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() + + serversObject := s.getMergedServersObject(provider) + return len(serversObject.Servers) +} + +// FormatToMarkdown Markdown formats the servers for the provider given +// and returns the resulting string. +func (s *Storage) FormatToMarkdown(provider string) (formatted string) { + if provider == providers.Custom { + return "" + } + + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() + + serversObject := s.getMergedServersObject(provider) + formatted = serversObject.ToMarkdown(provider) + return formatted +} + +// GetServersCount returns the number of servers for the provider given. +func (s *Storage) ServersAreEqual(provider string, servers []models.Server) (equal bool) { + if provider == providers.Custom { + return true + } + + s.mergedMutex.RLock() + defer s.mergedMutex.RUnlock() + + serversObject := s.getMergedServersObject(provider) + existingServers := serversObject.Servers + + if len(existingServers) != len(servers) { + return false + } + + for i := range existingServers { + if !existingServers[i].Equal(servers[i]) { + return false + } + } + + return true +} + +func (s *Storage) getMergedServersObject(provider string) (serversObject models.Servers) { + serversObject, ok := s.mergedServers.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("provider %s not found in in-memory servers map", provider)) + } + return serversObject } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index afa02644..e8bb8fd2 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -2,11 +2,14 @@ package storage import ( + "sync" + "github.com/qdm12/gluetun/internal/models" ) type Storage struct { mergedServers models.AllServers + mergedMutex sync.RWMutex // this is stored in memory to avoid re-parsing // the embedded JSON file on every call to the // SyncServers method. diff --git a/internal/storage/sync.go b/internal/storage/sync.go index a5c88a5b..c58cb4fe 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -29,6 +29,9 @@ func (s *Storage) syncServers() (err error) { hardcodedCount := countServers(s.hardcodedServers) countOnFile := countServers(serversOnFile) + s.mergedMutex.Lock() + defer s.mergedMutex.Unlock() + if countOnFile == 0 { s.logger.Info(fmt.Sprintf( "creating %s with %d hardcoded servers", @@ -47,7 +50,8 @@ func (s *Storage) syncServers() (err error) { return nil } - if err := flushToFile(s.filepath, &s.mergedServers); err != nil { + err = s.flushToFile(s.filepath) + if err != nil { return fmt.Errorf("cannot write servers to file: %w", err) } return nil diff --git a/internal/updater/loop/loop.go b/internal/updater/loop/loop.go index 1c9eaca4..7ccfe131 100644 --- a/internal/updater/loop/loop.go +++ b/internal/updater/loop/loop.go @@ -9,7 +9,6 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/updater" ) @@ -24,16 +23,14 @@ type Looper interface { } type Updater interface { - UpdateServers(ctx context.Context) (allServers models.AllServers, err error) + UpdateServers(ctx context.Context) (err error) } type looper struct { state state // Objects - updater Updater - flusher storage.Flusher - setAllServers func(allServers models.AllServers) - logger Logger + updater Updater + logger Logger // Internal channels and locks loopLock sync.Mutex start chan struct{} @@ -49,32 +46,35 @@ type looper struct { const defaultBackoffTime = 5 * time.Second +type Storage interface { + SetServers(provider string, servers []models.Server) (err error) + GetServersCount(provider string) (count int) + ServersAreEqual(provider string, servers []models.Server) (equal bool) +} + type Logger interface { Info(s string) Warn(s string) Error(s string) } -func NewLooper(settings settings.Updater, currentServers models.AllServers, - flusher storage.Flusher, setAllServers func(allServers models.AllServers), +func NewLooper(settings settings.Updater, storage Storage, client *http.Client, logger Logger) Looper { return &looper{ state: state{ status: constants.Stopped, settings: settings, }, - updater: updater.New(settings, client, currentServers, logger), - flusher: flusher, - setAllServers: setAllServers, - logger: logger, - start: make(chan struct{}), - running: make(chan models.LoopStatus), - stop: make(chan struct{}), - stopped: make(chan struct{}), - updateTicker: make(chan struct{}), - timeNow: time.Now, - timeSince: time.Since, - backoffTime: defaultBackoffTime, + updater: updater.New(settings, client, storage, logger), + logger: logger, + start: make(chan struct{}), + running: make(chan models.LoopStatus), + stop: make(chan struct{}), + stopped: make(chan struct{}), + updateTicker: make(chan struct{}), + timeNow: time.Now, + timeSince: time.Since, + backoffTime: defaultBackoffTime, } } @@ -106,20 +106,19 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { for ctx.Err() == nil { updateCtx, updateCancel := context.WithCancel(ctx) - serversCh := make(chan models.AllServers) errorCh := make(chan error) runWg := &sync.WaitGroup{} runWg.Add(1) go func() { defer runWg.Done() - servers, err := l.updater.UpdateServers(updateCtx) + err := l.updater.UpdateServers(updateCtx) if err != nil { if updateCtx.Err() == nil { errorCh <- err } return } - serversCh <- servers + l.state.setStatusWithLock(constants.Completed) }() if !crashed { @@ -148,16 +147,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { updateCancel() runWg.Wait() l.stopped <- struct{}{} - case servers := <-serversCh: - l.setAllServers(servers) - if err := l.flusher.FlushToFile(&servers); err != nil { - l.logger.Error(err.Error()) - } - runWg.Wait() - l.state.setStatusWithLock(constants.Completed) - l.logger.Info("Updated servers information") case err := <-errorCh: - close(serversCh) runWg.Wait() l.state.setStatusWithLock(constants.Crashed) l.logAndWait(ctx, err) diff --git a/internal/updater/providers.go b/internal/updater/providers.go index 3fdb36d7..35e5ecad 100644 --- a/internal/updater/providers.go +++ b/internal/updater/providers.go @@ -3,8 +3,6 @@ package updater import ( "context" "fmt" - "reflect" - "time" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" @@ -31,18 +29,25 @@ import ( ) func (u *Updater) updateProvider(ctx context.Context, provider string) (err error) { - existingServers := u.getProviderServers(provider) - minServers := getMinServers(existingServers) + existingServersCount := u.storage.GetServersCount(provider) + minServers := getMinServers(existingServersCount) servers, err := u.getServers(ctx, provider, minServers) if err != nil { - return err + return fmt.Errorf("cannot get servers: %w", err) } - if reflect.DeepEqual(existingServers, servers) { + if u.storage.ServersAreEqual(provider, servers) { return nil } - u.patchProvider(provider, servers) + // Note the servers variable must NOT BE MUTATED after this call, + // since the implementation does not deep copy the servers. + // TODO set in storage in provider updater directly, server by server, + // to avoid accumulating server data in memory. + err = u.storage.SetServers(provider, servers) + if err != nil { + return fmt.Errorf("cannot set servers to storage: %w", err) + } return nil } @@ -101,25 +106,7 @@ func (u *Updater) getServers(ctx context.Context, provider string, return providerUpdater.GetServers(ctx, minServers) } -func (u *Updater) getProviderServers(provider string) (servers []models.Server) { - providerServers, ok := u.servers.ProviderToServers[provider] - if !ok { - panic(fmt.Sprintf("provider %s is unknown", provider)) - } - return providerServers.Servers -} - -func getMinServers(servers []models.Server) (minServers int) { +func getMinServers(existingServersCount int) (minServers int) { const minRatio = 0.8 - return int(minRatio * float64(len(servers))) -} - -func (u *Updater) patchProvider(provider string, servers []models.Server) { - providerServers, ok := u.servers.ProviderToServers[provider] - if !ok { - panic(fmt.Sprintf("provider %s is unknown", provider)) - } - providerServers.Timestamp = time.Now().Unix() - providerServers.Servers = servers - u.servers.ProviderToServers[provider] = providerServers + return int(minRatio * float64(existingServersCount)) } diff --git a/internal/updater/updater.go b/internal/updater/updater.go index a8746de9..1df4db50 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -19,7 +19,7 @@ type Updater struct { options settings.Updater // state - servers models.AllServers + storage Storage // Functions for tests logger Logger @@ -29,6 +29,12 @@ type Updater struct { unzipper unzip.Unzipper } +type Storage interface { + SetServers(provider string, servers []models.Server) (err error) + GetServersCount(provider string) (count int) + ServersAreEqual(provider string, servers []models.Server) (equal bool) +} + type Logger interface { Info(s string) Warn(s string) @@ -36,20 +42,20 @@ type Logger interface { } func New(settings settings.Updater, httpClient *http.Client, - currentServers models.AllServers, logger Logger) *Updater { + storage Storage, logger Logger) *Updater { unzipper := unzip.New(httpClient) return &Updater{ + options: settings, + storage: storage, logger: logger, timeNow: time.Now, presolver: resolver.NewParallelResolver(settings.DNSAddress.String()), client: httpClient, unzipper: unzipper, - options: settings, - servers: currentServers, } } -func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) { +func (u *Updater) UpdateServers(ctx context.Context) (err error) { caser := cases.Title(language.English) for _, provider := range u.options.Providers { u.logger.Info("updating " + caser.String(provider) + " servers...") @@ -62,17 +68,17 @@ func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServe // return the only error for the single provider. if len(u.options.Providers) == 1 { - return allServers, err + return err } // stop updating the next providers if context is canceled. if ctxErr := ctx.Err(); ctxErr != nil { - return allServers, ctxErr + return ctxErr } // Log the error and continue updating the next provider. u.logger.Error(err.Error()) } - return u.servers, nil + return nil } diff --git a/internal/vpn/loop.go b/internal/vpn/loop.go index 2a43104c..76ae2ffc 100644 --- a/internal/vpn/loop.go +++ b/internal/vpn/loop.go @@ -27,12 +27,12 @@ type Looper interface { loopstate.Getter loopstate.Applier SettingsGetSetter - ServersGetterSetter } type Loop struct { statusManager loopstate.Manager state state.Manager + storage Storage // Fixed parameters buildInfo models.BuildInformation versionInfo bool @@ -64,12 +64,17 @@ type firewallConfigurer interface { firewall.PortAllower } +type Storage interface { + FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error) + GetServerByName(provider, name string) (server models.Server, ok bool) +} + const ( defaultBackoffTime = 15 * time.Second ) func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16, - allServers models.AllServers, openvpnConf openvpn.Interface, + storage Storage, openvpnConf openvpn.Interface, netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter, portForward portforward.StartStopper, starter command.Starter, publicip publicip.Looper, dnsLooper dns.Looper, @@ -81,11 +86,12 @@ func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16, stopped := make(chan struct{}) statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) - state := state.New(statusManager, vpnSettings, allServers) + state := state.New(statusManager, vpnSettings) return &Loop{ statusManager: statusManager, state: state, + storage: storage, buildInfo: buildInfo, versionInfo: versionInfo, vpnInputPorts: vpnInputPorts, diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 9e40bd8d..42f9c9d5 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -28,9 +28,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { } for ctx.Err() == nil { - settings, allServers := l.state.GetSettingsAndServers() + settings := l.state.GetSettings() - providerConf := provider.New(*settings.Provider.Name, allServers, time.Now) + providerConf := provider.New(*settings.Provider.Name, l.storage, time.Now) portForwarding := *settings.Provider.PortForwarding.Enabled var vpnRunner vpnRunner diff --git a/internal/vpn/servers.go b/internal/vpn/servers.go deleted file mode 100644 index 61d8fb68..00000000 --- a/internal/vpn/servers.go +++ /dev/null @@ -1,16 +0,0 @@ -package vpn - -import ( - "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/vpn/state" -) - -type ServersGetterSetter = state.ServersGetterSetter - -func (l *Loop) GetServers() (servers models.AllServers) { - return l.state.GetServers() -} - -func (l *Loop) SetServers(servers models.AllServers) { - l.state.SetServers(servers) -} diff --git a/internal/vpn/state/servers.go b/internal/vpn/state/servers.go deleted file mode 100644 index 547312fe..00000000 --- a/internal/vpn/state/servers.go +++ /dev/null @@ -1,20 +0,0 @@ -package state - -import "github.com/qdm12/gluetun/internal/models" - -type ServersGetterSetter interface { - GetServers() (servers models.AllServers) - SetServers(servers models.AllServers) -} - -func (s *State) GetServers() (servers models.AllServers) { - s.allServersMu.RLock() - defer s.allServersMu.RUnlock() - return s.allServers -} - -func (s *State) SetServers(servers models.AllServers) { - s.allServersMu.Lock() - defer s.allServersMu.Unlock() - s.allServers = servers -} diff --git a/internal/vpn/state/state.go b/internal/vpn/state/state.go index dfb99418..df48aaa7 100644 --- a/internal/vpn/state/state.go +++ b/internal/vpn/state/state.go @@ -5,23 +5,18 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/loopstate" - "github.com/qdm12/gluetun/internal/models" ) var _ Manager = (*State)(nil) type Manager interface { SettingsGetSetter - ServersGetterSetter - GetSettingsAndServers() (vpn settings.VPN, allServers models.AllServers) } -func New(statusApplier loopstate.Applier, - vpn settings.VPN, allServers models.AllServers) *State { +func New(statusApplier loopstate.Applier, vpn settings.VPN) *State { return &State{ statusApplier: statusApplier, vpn: vpn, - allServers: allServers, } } @@ -30,18 +25,4 @@ type State struct { vpn settings.VPN settingsMu sync.RWMutex - - allServers models.AllServers - allServersMu sync.RWMutex -} - -func (s *State) GetSettingsAndServers() (vpn settings.VPN, - allServers models.AllServers) { - s.settingsMu.RLock() - s.allServersMu.RLock() - vpn = s.vpn - allServers = s.allServers - s.settingsMu.RUnlock() - s.allServersMu.RUnlock() - return vpn, allServers }