From bd0868d764e9a6cccf629188748bfe85374a5ef0 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 27 May 2022 00:59:47 +0000 Subject: [PATCH] chore(all): provider to servers map in `allServers` - Simplify formatting CLI - Simplify updater code - Simplify filter choices for config validation - Simplify all servers deep copying - Custom JSON marshaling methods for `AllServers` - Simplify provider constructor switch - Simplify storage merging - Simplify storage reading and extraction - Simplify updating code --- internal/cli/formatservers.go | 89 ++-- internal/cli/update.go | 13 +- internal/configuration/settings/provider.go | 2 +- .../configuration/settings/serverselection.go | 120 +---- internal/configuration/settings/updater.go | 4 - .../settings/validation/servers.go | 18 + internal/constants/providers/providers.go | 10 +- .../constants/providers/providers_test.go | 23 + internal/models/getservers.go | 113 +--- internal/models/getservers_test.go | 227 ++++----- internal/models/servers.go | 195 +++++-- internal/models/servers_test.go | 189 +++++++ internal/provider/provider.go | 41 +- internal/storage/flush.go | 6 +- internal/storage/hardcoded_test.go | 12 +- internal/storage/merge.go | 35 +- internal/storage/read.go | 260 ++-------- internal/storage/read_test.go | 220 ++++---- internal/storage/sync.go | 28 +- internal/updater/loop.go | 2 +- internal/updater/providers.go | 481 +++--------------- internal/updater/updater.go | 61 +-- 22 files changed, 854 insertions(+), 1295 deletions(-) create mode 100644 internal/constants/providers/providers_test.go create mode 100644 internal/models/servers_test.go diff --git a/internal/cli/formatservers.go b/internal/cli/formatservers.go index 8042d51f..ae73dbdd 100644 --- a/internal/cli/formatservers.go +++ b/internal/cli/formatservers.go @@ -10,6 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/storage" ) @@ -18,8 +19,9 @@ type ServersFormatter interface { } var ( - ErrFormatNotRecognized = errors.New("format is not recognized") - ErrProviderUnspecified = errors.New("VPN provider to format was not specified") + ErrFormatNotRecognized = errors.New("format is not recognized") + ErrProviderUnspecified = errors.New("VPN provider to format was not specified") + ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified") ) func addProviderFlag(flagSet *flag.FlagSet, @@ -31,21 +33,12 @@ func addProviderFlag(flagSet *flag.FlagSet, flagSet.BoolVar(boolPtr, provider, false, "Format "+strings.Title(provider)+" servers") } -func getFormatForProvider(providerToFormat map[string]*bool, provider string) (format bool) { - formatPtr, ok := providerToFormat[provider] - if !ok { - panic(fmt.Sprintf("unknown provider in format map: %s", provider)) - } - return *formatPtr -} - func (c *CLI) FormatServers(args []string) error { var format, output string allProviders := providers.All() providersToFormat := make(map[string]*bool, len(allProviders)) for _, provider := range allProviders { - value := false - providersToFormat[provider] = &value + providersToFormat[provider] = new(bool) } flagSet := flag.NewFlagSet("markdown", flag.ExitOnError) flagSet.StringVar(&format, "format", "markdown", "Format to use which can be: 'markdown'") @@ -61,6 +54,24 @@ func (c *CLI) FormatServers(args []string) error { return fmt.Errorf("%w: %s", ErrFormatNotRecognized, format) } + // Verify only one provider is set to be formatted. + var providers []string + for provider, formatPtr := range providersToFormat { + if *formatPtr { + providers = append(providers, provider) + } + } + switch len(providers) { + case 0: + return ErrProviderUnspecified + case 1: + default: + return fmt.Errorf("%w: %d specified: %s", + ErrMultipleProvidersToFormat, len(providers), + strings.Join(providers, ", ")) + } + providerToFormat := providers[0] + logger := newNoopLogger() storage, err := storage.New(logger, constants.ServersData) if err != nil { @@ -68,51 +79,7 @@ func (c *CLI) FormatServers(args []string) error { } currentServers := storage.GetServers() - var formatted string - switch { - case getFormatForProvider(providersToFormat, providers.Cyberghost): - formatted = currentServers.Cyberghost.ToMarkdown(providers.Cyberghost) - case getFormatForProvider(providersToFormat, providers.Expressvpn): - formatted = currentServers.Expressvpn.ToMarkdown(providers.Expressvpn) - case getFormatForProvider(providersToFormat, providers.Fastestvpn): - formatted = currentServers.Fastestvpn.ToMarkdown(providers.Fastestvpn) - case getFormatForProvider(providersToFormat, providers.HideMyAss): - formatted = currentServers.HideMyAss.ToMarkdown(providers.HideMyAss) - case getFormatForProvider(providersToFormat, providers.Ipvanish): - formatted = currentServers.Ipvanish.ToMarkdown(providers.Ipvanish) - case getFormatForProvider(providersToFormat, providers.Ivpn): - formatted = currentServers.Ivpn.ToMarkdown(providers.Ivpn) - case getFormatForProvider(providersToFormat, providers.Mullvad): - formatted = currentServers.Mullvad.ToMarkdown(providers.Mullvad) - case getFormatForProvider(providersToFormat, providers.Nordvpn): - formatted = currentServers.Nordvpn.ToMarkdown(providers.Nordvpn) - case getFormatForProvider(providersToFormat, providers.Perfectprivacy): - formatted = currentServers.Perfectprivacy.ToMarkdown(providers.Perfectprivacy) - case getFormatForProvider(providersToFormat, providers.PrivateInternetAccess): - formatted = currentServers.Pia.ToMarkdown(providers.PrivateInternetAccess) - case getFormatForProvider(providersToFormat, providers.Privado): - formatted = currentServers.Privado.ToMarkdown(providers.Privado) - case getFormatForProvider(providersToFormat, providers.Privatevpn): - formatted = currentServers.Privatevpn.ToMarkdown(providers.Privatevpn) - case getFormatForProvider(providersToFormat, providers.Protonvpn): - formatted = currentServers.Protonvpn.ToMarkdown(providers.Protonvpn) - case getFormatForProvider(providersToFormat, providers.Purevpn): - formatted = currentServers.Purevpn.ToMarkdown(providers.Purevpn) - case getFormatForProvider(providersToFormat, providers.Surfshark): - formatted = currentServers.Surfshark.ToMarkdown(providers.Surfshark) - case getFormatForProvider(providersToFormat, providers.Torguard): - formatted = currentServers.Torguard.ToMarkdown(providers.Torguard) - case getFormatForProvider(providersToFormat, providers.VPNUnlimited): - formatted = currentServers.VPNUnlimited.ToMarkdown(providers.VPNUnlimited) - case getFormatForProvider(providersToFormat, providers.Vyprvpn): - formatted = currentServers.Vyprvpn.ToMarkdown(providers.Vyprvpn) - case getFormatForProvider(providersToFormat, providers.Wevpn): - formatted = currentServers.Wevpn.ToMarkdown(providers.Wevpn) - case getFormatForProvider(providersToFormat, providers.Windscribe): - formatted = currentServers.Windscribe.ToMarkdown(providers.Windscribe) - default: - return ErrProviderUnspecified - } + formatted := formatServers(currentServers, providerToFormat) output = filepath.Clean(output) file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) @@ -133,3 +100,11 @@ func (c *CLI) FormatServers(args []string) error { return nil } + +func formatServers(allServers models.AllServers, provider string) (formatted string) { + servers, ok := allServers.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("unknown provider in format map: %s", provider)) + } + return servers.ToMarkdown(providers.Cyberghost) +} diff --git a/internal/cli/update.go b/internal/cli/update.go index 79cfe284..3eb548ef 100644 --- a/internal/cli/update.go +++ b/internal/cli/update.go @@ -63,12 +63,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e } if updateAll { - for _, provider := range providers.All() { - if provider == providers.Custom { - continue - } - options.Providers = append(options.Providers, provider) - } + options.Providers = providers.All() } else { if csvProviders == "" { return ErrNoProviderSpecified @@ -99,13 +94,13 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e } if endUserMode { - if err := storage.FlushToFile(allServers); err != nil { + if err := storage.FlushToFile(&allServers); err != nil { return fmt.Errorf("cannot write updated information to file: %w", err) } } if maintainerMode { - if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil { + if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil { return fmt.Errorf("cannot write updated information to file: %w", err) } } @@ -114,7 +109,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e } func writeToEmbeddedJSON(repoServersPath string, - allServers models.AllServers) error { + allServers *models.AllServers) error { const perms = 0600 f, err := os.OpenFile(repoServersPath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms) diff --git a/internal/configuration/settings/provider.go b/internal/configuration/settings/provider.go index 1e237511..dfd41076 100644 --- a/internal/configuration/settings/provider.go +++ b/internal/configuration/settings/provider.go @@ -27,7 +27,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e // Validate Name var validNames []string if vpnType == vpn.OpenVPN { - validNames = providers.All() + validNames = providers.AllWithCustom() validNames = append(validNames, "pia") // Retro-compatibility } else { // Wireguard validNames = []string{ diff --git a/internal/configuration/settings/serverselection.go b/internal/configuration/settings/serverselection.go index a8b84390..d885ea54 100644 --- a/internal/configuration/settings/serverselection.go +++ b/internal/configuration/settings/serverselection.go @@ -140,118 +140,26 @@ func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection, countryChoices, regionChoices, cityChoices, ispChoices, nameChoices, hostnameChoices []string, err error) { - switch vpnServiceProvider { - case providers.Custom: - case providers.Cyberghost: - servers := allServers.GetCyberghost() - countryChoices = validation.ExtractCountries(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Expressvpn: - servers := allServers.GetExpressvpn() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Fastestvpn: - servers := allServers.GetFastestvpn() - countryChoices = validation.ExtractCountries(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.HideMyAss: - servers := allServers.GetHideMyAss() - countryChoices = validation.ExtractCountries(servers) - regionChoices = validation.ExtractRegions(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Ipvanish: - servers := allServers.GetIpvanish() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Ivpn: - servers := allServers.GetIvpn() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - ispChoices = validation.ExtractISPs(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Mullvad: - servers := allServers.GetMullvad() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - ispChoices = validation.ExtractISPs(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Nordvpn: - servers := allServers.GetNordvpn() - regionChoices = validation.ExtractRegions(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Perfectprivacy: - servers := allServers.GetPerfectprivacy() - cityChoices = validation.ExtractCities(servers) - case providers.Privado: - servers := allServers.GetPrivado() - countryChoices = validation.ExtractCountries(servers) - regionChoices = validation.ExtractRegions(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.PrivateInternetAccess: - servers := allServers.GetPia() - regionChoices = validation.ExtractRegions(servers) - hostnameChoices = validation.ExtractHostnames(servers) - nameChoices = validation.ExtractServerNames(servers) - case providers.Privatevpn: - servers := allServers.GetPrivatevpn() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Protonvpn: - servers := allServers.GetProtonvpn() - countryChoices = validation.ExtractCountries(servers) - regionChoices = validation.ExtractRegions(servers) - cityChoices = validation.ExtractCities(servers) - nameChoices = validation.ExtractServerNames(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Purevpn: - servers := allServers.GetPurevpn() - countryChoices = validation.ExtractCountries(servers) - regionChoices = validation.ExtractRegions(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Surfshark: - servers := allServers.GetSurfshark() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - regionChoices = validation.ExtractRegions(servers) + providerServers, ok := allServers.ProviderToServers[vpnServiceProvider] + if !ok { + panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider)) + } + servers := providerServers.Servers + countryChoices = validation.ExtractCountries(servers) + regionChoices = validation.ExtractRegions(servers) + cityChoices = validation.ExtractCities(servers) + ispChoices = validation.ExtractISPs(servers) + nameChoices = validation.ExtractServerNames(servers) + hostnameChoices = validation.ExtractHostnames(servers) + + if vpnServiceProvider == providers.Surfshark { + // // Retro compatibility // TODO v4 remove regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...) if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil { return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err) } - // Retro compatibility - // TODO remove in v4 *ss = surfsharkRetroRegion(*ss) - case providers.Torguard: - servers := allServers.GetTorguard() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.VPNUnlimited: - servers := allServers.GetVPNUnlimited() - countryChoices = validation.ExtractCountries(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Vyprvpn: - servers := allServers.GetVyprvpn() - regionChoices = validation.ExtractRegions(servers) - case providers.Wevpn: - servers := allServers.GetWevpn() - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - case providers.Windscribe: - servers := allServers.GetWindscribe() - regionChoices = validation.ExtractRegions(servers) - cityChoices = validation.ExtractCities(servers) - hostnameChoices = validation.ExtractHostnames(servers) - default: - return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrVPNProviderNameNotValid, vpnServiceProvider) } return countryChoices, regionChoices, cityChoices, diff --git a/internal/configuration/settings/updater.go b/internal/configuration/settings/updater.go index a930f22f..92cd55c3 100644 --- a/internal/configuration/settings/updater.go +++ b/internal/configuration/settings/updater.go @@ -43,10 +43,6 @@ func (u Updater) Validate() (err error) { for i, provider := range u.Providers { valid := false for _, validProvider := range providers.All() { - if validProvider == providers.Custom { - continue - } - if provider == validProvider { valid = true break diff --git a/internal/configuration/settings/validation/servers.go b/internal/configuration/settings/validation/servers.go index 9606e3dd..fb1d7f4e 100644 --- a/internal/configuration/settings/validation/servers.go +++ b/internal/configuration/settings/validation/servers.go @@ -19,6 +19,9 @@ func ExtractCountries(servers []models.Server) (values []string) { values = make([]string, 0, len(servers)) for _, server := range servers { value := server.Country + if value == "" { + continue + } _, alreadySeen := seen[value] if alreadySeen { continue @@ -35,6 +38,9 @@ func ExtractRegions(servers []models.Server) (values []string) { values = make([]string, 0, len(servers)) for _, server := range servers { value := server.Region + if value == "" { + continue + } _, alreadySeen := seen[value] if alreadySeen { continue @@ -51,6 +57,9 @@ func ExtractCities(servers []models.Server) (values []string) { values = make([]string, 0, len(servers)) for _, server := range servers { value := server.City + if value == "" { + continue + } _, alreadySeen := seen[value] if alreadySeen { continue @@ -67,6 +76,9 @@ func ExtractISPs(servers []models.Server) (values []string) { values = make([]string, 0, len(servers)) for _, server := range servers { value := server.ISP + if value == "" { + continue + } _, alreadySeen := seen[value] if alreadySeen { continue @@ -83,6 +95,9 @@ func ExtractServerNames(servers []models.Server) (values []string) { values = make([]string, 0, len(servers)) for _, server := range servers { value := server.ServerName + if value == "" { + continue + } _, alreadySeen := seen[value] if alreadySeen { continue @@ -99,6 +114,9 @@ func ExtractHostnames(servers []models.Server) (values []string) { values = make([]string, 0, len(servers)) for _, server := range servers { value := server.Hostname + if value == "" { + continue + } _, alreadySeen := seen[value] if alreadySeen { continue diff --git a/internal/constants/providers/providers.go b/internal/constants/providers/providers.go index 52ef1ca0..ca683051 100644 --- a/internal/constants/providers/providers.go +++ b/internal/constants/providers/providers.go @@ -26,9 +26,9 @@ const ( Windscribe = "windscribe" ) +// All returns all the providers except the custom provider. func All() []string { return []string{ - Custom, Cyberghost, Expressvpn, Fastestvpn, @@ -51,3 +51,11 @@ func All() []string { Windscribe, } } + +func AllWithCustom() []string { + allProviders := All() + allProvidersWithCustom := make([]string, len(allProviders)+1) + copy(allProvidersWithCustom, allProviders) + allProvidersWithCustom[len(allProvidersWithCustom)-1] = Custom + return allProvidersWithCustom +} diff --git a/internal/constants/providers/providers_test.go b/internal/constants/providers/providers_test.go new file mode 100644 index 00000000..88491dd9 --- /dev/null +++ b/internal/constants/providers/providers_test.go @@ -0,0 +1,23 @@ +package providers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_All(t *testing.T) { + t.Parallel() + + all := All() + assert.NotContains(t, all, Custom) + assert.NotEmpty(t, all) +} + +func Test_AllWithCustom(t *testing.T) { + t.Parallel() + + all := AllWithCustom() + assert.Contains(t, all, Custom) + assert.Len(t, all, len(All())+1) +} diff --git a/internal/models/getservers.go b/internal/models/getservers.go index d2cdd201..023347ee 100644 --- a/internal/models/getservers.go +++ b/internal/models/getservers.go @@ -4,108 +4,17 @@ import ( "net" ) -func (a AllServers) GetCopy() (servers AllServers) { - servers = a // copy versions and timestamps - servers.Cyberghost.Servers = a.GetCyberghost() - servers.Expressvpn.Servers = a.GetExpressvpn() - servers.Fastestvpn.Servers = a.GetFastestvpn() - servers.HideMyAss.Servers = a.GetHideMyAss() - servers.Ipvanish.Servers = a.GetIpvanish() - servers.Ivpn.Servers = a.GetIvpn() - servers.Mullvad.Servers = a.GetMullvad() - servers.Nordvpn.Servers = a.GetNordvpn() - servers.Perfectprivacy.Servers = a.GetPerfectprivacy() - servers.Privado.Servers = a.GetPrivado() - servers.Pia.Servers = a.GetPia() - servers.Privatevpn.Servers = a.GetPrivatevpn() - servers.Protonvpn.Servers = a.GetProtonvpn() - servers.Purevpn.Servers = a.GetPurevpn() - servers.Surfshark.Servers = a.GetSurfshark() - servers.Torguard.Servers = a.GetTorguard() - servers.VPNUnlimited.Servers = a.GetVPNUnlimited() - servers.Vyprvpn.Servers = a.GetVyprvpn() - servers.Windscribe.Servers = a.GetWindscribe() - return servers -} - -func (a *AllServers) GetCyberghost() (servers []Server) { - return copyServers(a.Cyberghost.Servers) -} - -func (a *AllServers) GetExpressvpn() (servers []Server) { - return copyServers(a.Expressvpn.Servers) -} - -func (a *AllServers) GetFastestvpn() (servers []Server) { - return copyServers(a.Fastestvpn.Servers) -} - -func (a *AllServers) GetHideMyAss() (servers []Server) { - return copyServers(a.HideMyAss.Servers) -} - -func (a *AllServers) GetIpvanish() (servers []Server) { - return copyServers(a.Ipvanish.Servers) -} - -func (a *AllServers) GetIvpn() (servers []Server) { - return copyServers(a.Ivpn.Servers) -} - -func (a *AllServers) GetMullvad() (servers []Server) { - return copyServers(a.Mullvad.Servers) -} - -func (a *AllServers) GetNordvpn() (servers []Server) { - return copyServers(a.Nordvpn.Servers) -} - -func (a *AllServers) GetPerfectprivacy() (servers []Server) { - return copyServers(a.Perfectprivacy.Servers) -} - -func (a *AllServers) GetPia() (servers []Server) { - return copyServers(a.Pia.Servers) -} - -func (a *AllServers) GetPrivado() (servers []Server) { - return copyServers(a.Privado.Servers) -} - -func (a *AllServers) GetPrivatevpn() (servers []Server) { - return copyServers(a.Privatevpn.Servers) -} - -func (a *AllServers) GetProtonvpn() (servers []Server) { - return copyServers(a.Protonvpn.Servers) -} - -func (a *AllServers) GetPurevpn() (servers []Server) { - return copyServers(a.Purevpn.Servers) -} - -func (a *AllServers) GetSurfshark() (servers []Server) { - return copyServers(a.Surfshark.Servers) -} - -func (a *AllServers) GetTorguard() (servers []Server) { - return copyServers(a.Torguard.Servers) -} - -func (a *AllServers) GetVPNUnlimited() (servers []Server) { - return copyServers(a.VPNUnlimited.Servers) -} - -func (a *AllServers) GetVyprvpn() (servers []Server) { - return copyServers(a.Vyprvpn.Servers) -} - -func (a *AllServers) GetWevpn() (servers []Server) { - return copyServers(a.Wevpn.Servers) -} - -func (a *AllServers) GetWindscribe() (servers []Server) { - return copyServers(a.Windscribe.Servers) +func (a AllServers) GetCopy() (allServersCopy AllServers) { + allServersCopy.Version = a.Version + allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers)) + for provider, servers := range a.ProviderToServers { + allServersCopy.ProviderToServers[provider] = Servers{ + Version: servers.Version, + Timestamp: servers.Timestamp, + Servers: copyServers(servers.Servers), + } + } + return allServersCopy } func copyServers(servers []Server) (serversCopy []Server) { diff --git a/internal/models/getservers_test.go b/internal/models/getservers_test.go index dff410ac..1381803e 100644 --- a/internal/models/getservers_test.go +++ b/internal/models/getservers_test.go @@ -4,108 +4,117 @@ import ( "net" "testing" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_AllServers_GetCopy(t *testing.T) { allServers := AllServers{ - Cyberghost: Servers{ - Version: 2, - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Expressvpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Fastestvpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - HideMyAss: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Ipvanish: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Ivpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Mullvad: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Nordvpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Perfectprivacy: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Privado: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Pia: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Privatevpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Protonvpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Purevpn: Servers{ - Version: 1, - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Surfshark: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Torguard: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - VPNUnlimited: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Vyprvpn: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, - }, - Windscribe: Servers{ - Servers: []Server{{ - IPs: []net.IP{{1, 2, 3, 4}}, - }}, + Version: 1, + ProviderToServers: map[string]Servers{ + providers.Cyberghost: { + Version: 2, + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Expressvpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Fastestvpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.HideMyAss: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Ipvanish: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Ivpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Mullvad: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Nordvpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Perfectprivacy: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Privado: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.PrivateInternetAccess: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Privatevpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Protonvpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Purevpn: { + Version: 1, + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Surfshark: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Torguard: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.VPNUnlimited: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Vyprvpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Wevpn: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, + providers.Windscribe: { + Servers: []Server{{ + IPs: []net.IP{{1, 2, 3, 4}}, + }}, + }, }, } @@ -114,32 +123,6 @@ func Test_AllServers_GetCopy(t *testing.T) { assert.Equal(t, allServers, servers) } -func Test_AllServers_GetVyprvpn(t *testing.T) { - allServers := AllServers{ - Vyprvpn: Servers{ - Servers: []Server{ - {Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}}, - {Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}}, - }, - }, - } - - servers := allServers.GetVyprvpn() - - expectedServers := []Server{ - {Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}}, - {Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}}, - } - assert.Equal(t, expectedServers, servers) - - allServers.Vyprvpn.Servers[0].IPs[0][0] = 9 - assert.NotEqual(t, 9, servers[0].IPs[0][0]) - - allServers.Vyprvpn.Servers[0].IPs[0][0] = 1 - servers[0].IPs[0][0] = 9 - assert.NotEqual(t, 9, allServers.Vyprvpn.Servers[0].IPs[0][0]) -} - func Test_copyIPs(t *testing.T) { t.Parallel() diff --git a/internal/models/servers.go b/internal/models/servers.go index ce6cb1a2..d594f726 100644 --- a/internal/models/servers.go +++ b/internal/models/servers.go @@ -1,54 +1,163 @@ package models +import ( + "bytes" + "encoding/json" + "fmt" + "math" + "reflect" + + "github.com/qdm12/gluetun/internal/constants/providers" +) + type AllServers struct { - Version uint16 `json:"version"` // used for migration of the top level scheme - Cyberghost Servers `json:"cyberghost"` - Expressvpn Servers `json:"expressvpn"` - Fastestvpn Servers `json:"fastestvpn"` - HideMyAss Servers `json:"hidemyass"` - Ipvanish Servers `json:"ipvanish"` - Ivpn Servers `json:"ivpn"` - Mullvad Servers `json:"mullvad"` - Perfectprivacy Servers `json:"perfect privacy"` - Nordvpn Servers `json:"nordvpn"` - Privado Servers `json:"privado"` - Pia Servers `json:"private internet access"` - Privatevpn Servers `json:"privatevpn"` - Protonvpn Servers `json:"protonvpn"` - Purevpn Servers `json:"purevpn"` - Surfshark Servers `json:"surfshark"` - Torguard Servers `json:"torguard"` - VPNUnlimited Servers `json:"vpn unlimited"` - Vyprvpn Servers `json:"vyprvpn"` - Wevpn Servers `json:"wevpn"` - Windscribe Servers `json:"windscribe"` + Version uint16 // used for migration of the top level scheme + ProviderToServers map[string]Servers } -func (a *AllServers) Count() int { - return len(a.Cyberghost.Servers) + - len(a.Expressvpn.Servers) + - len(a.Fastestvpn.Servers) + - len(a.HideMyAss.Servers) + - len(a.Ipvanish.Servers) + - len(a.Ivpn.Servers) + - len(a.Mullvad.Servers) + - len(a.Nordvpn.Servers) + - len(a.Perfectprivacy.Servers) + - len(a.Privado.Servers) + - len(a.Pia.Servers) + - len(a.Privatevpn.Servers) + - len(a.Protonvpn.Servers) + - len(a.Purevpn.Servers) + - len(a.Surfshark.Servers) + - len(a.Torguard.Servers) + - len(a.VPNUnlimited.Servers) + - len(a.Vyprvpn.Servers) + - len(a.Wevpn.Servers) + - len(a.Windscribe.Servers) +func (a *AllServers) ServersSlice(provider string) []Server { + servers, ok := a.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("provider %s not found in all servers", provider)) + } + return copyServers(servers.Servers) +} + +var _ json.Marshaler = (*AllServers)(nil) + +// MarshalJSON marshals all servers to JSON. +// Note you need to use a pointer to all servers +// for it to work with native json methods such as +// json.Marshal. +func (a *AllServers) MarshalJSON() (data []byte, err error) { + buffer := bytes.NewBuffer(nil) + + _, err = buffer.WriteString("{") + if err != nil { + return nil, fmt.Errorf("cannot write opening bracket: %w", err) + } + + versionString := fmt.Sprintf(`"version":%d`, a.Version) + _, err = buffer.WriteString(versionString) + if err != nil { + return nil, fmt.Errorf("cannot write schema version string: %w", err) + } + + for _, provider := range providers.All() { + servers, ok := a.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("provider %s not found in all servers", provider)) + } + + providerKey := fmt.Sprintf(`,"%s":`, provider) + _, err = buffer.WriteString(providerKey) + if err != nil { + return nil, fmt.Errorf("cannot write provider key %s: %w", + providerKey, err) + } + + serversJSON, err := json.Marshal(servers) + if err != nil { + return nil, fmt.Errorf("failed encoding servers for provider %s: %w", + provider, err) + } + _, err = buffer.Write(serversJSON) + if err != nil { + return nil, fmt.Errorf("cannot write JSON servers data for provider %s: %w", + provider, err) + } + } + + _, err = buffer.WriteString("}") + if err != nil { + return nil, fmt.Errorf("cannot write closing bracket: %w", err) + } + + return buffer.Bytes(), nil +} + +var _ json.Unmarshaler = (*AllServers)(nil) + +func (a *AllServers) UnmarshalJSON(data []byte) (err error) { + keyValues := make(map[string]interface{}) + err = json.Unmarshal(data, &keyValues) + if err != nil { + return err + } + + versionUnmarshaled := keyValues["version"] + if versionUnmarshaled != nil { // defaults to 0 + version, ok := versionUnmarshaled.(float64) + if !ok { + return &json.UnmarshalTypeError{ + Value: fmt.Sprintf("number %v", versionUnmarshaled), + Type: reflect.TypeOf(uint16(0)), + Struct: "models.AllServers", + Field: "Version", + } + } + + if math.Round(version) != version || + version < 0 || version > float64(^uint16(0)) { + return &json.UnmarshalTypeError{ + Value: fmt.Sprintf("number %v", version), + Type: reflect.TypeOf(uint16(0)), + Struct: "models.AllServers", + Field: "Version", + } + } + + a.Version = uint16(version) + delete(keyValues, "version") + } + + if len(keyValues) == 0 { + return nil + } + + a.ProviderToServers = make(map[string]Servers, len(keyValues)) + + allProviders := providers.All() + allProvidersSet := make(map[string]struct{}, len(allProviders)) + for _, provider := range allProviders { + allProvidersSet[provider] = struct{}{} + } + + for key, value := range keyValues { + if _, ok := allProvidersSet[key]; !ok { + // not a provider known by Gluetun + // or a non-servers field. + continue + } + + jsonValue, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("cannot marshal %s servers: %w", + key, err) + } + + var servers Servers + err = json.Unmarshal(jsonValue, &servers) + if err != nil { + return fmt.Errorf("cannot unmarshal %s servers: %w", + key, err) + } + + a.ProviderToServers[key] = servers + } + + return nil +} + +func (a *AllServers) Count() (count int) { + for _, servers := range a.ProviderToServers { + count += len(servers.Servers) + } + return count } type Servers struct { Version uint16 `json:"version"` Timestamp int64 `json:"timestamp"` - Servers []Server `json:"servers"` + Servers []Server `json:"servers,omitempty"` } diff --git a/internal/models/servers_test.go b/internal/models/servers_test.go new file mode 100644 index 00000000..73caeaf9 --- /dev/null +++ b/internal/models/servers_test.go @@ -0,0 +1,189 @@ +package models + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_AllServers_MarshalJSON(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + allServers *AllServers + dataString string + errWrapped error + errMessage string + }{ + "empty": { + allServers: &AllServers{ + ProviderToServers: map[string]Servers{}, + }, + dataString: `{"version":0,` + + `"cyberghost":{"version":0,"timestamp":0},` + + `"expressvpn":{"version":0,"timestamp":0},` + + `"fastestvpn":{"version":0,"timestamp":0},` + + `"hidemyass":{"version":0,"timestamp":0},` + + `"ipvanish":{"version":0,"timestamp":0},` + + `"ivpn":{"version":0,"timestamp":0},` + + `"mullvad":{"version":0,"timestamp":0},` + + `"nordvpn":{"version":0,"timestamp":0},` + + `"perfect privacy":{"version":0,"timestamp":0},` + + `"privado":{"version":0,"timestamp":0},` + + `"private internet access":{"version":0,"timestamp":0},` + + `"privatevpn":{"version":0,"timestamp":0},` + + `"protonvpn":{"version":0,"timestamp":0},` + + `"purevpn":{"version":0,"timestamp":0},` + + `"surfshark":{"version":0,"timestamp":0},` + + `"torguard":{"version":0,"timestamp":0},` + + `"vpn unlimited":{"version":0,"timestamp":0},` + + `"vyprvpn":{"version":0,"timestamp":0},` + + `"wevpn":{"version":0,"timestamp":0},` + + `"windscribe":{"version":0,"timestamp":0}}`, + }, + "two known providers": { + allServers: &AllServers{ + Version: 1, + ProviderToServers: map[string]Servers{ + providers.Cyberghost: { + Version: 1, + Timestamp: 1000, + Servers: []Server{ + {Country: "A"}, + {Country: "B"}, + }, + }, + providers.Privado: { + Version: 2, + Timestamp: 2000, + Servers: []Server{ + {City: "C"}, + {City: "D"}, + }, + }, + }, + }, + dataString: `{"version":1,` + + `"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` + + `"expressvpn":{"version":0,"timestamp":0},` + + `"fastestvpn":{"version":0,"timestamp":0},` + + `"hidemyass":{"version":0,"timestamp":0},` + + `"ipvanish":{"version":0,"timestamp":0},` + + `"ivpn":{"version":0,"timestamp":0},` + + `"mullvad":{"version":0,"timestamp":0},` + + `"nordvpn":{"version":0,"timestamp":0},` + + `"perfect privacy":{"version":0,"timestamp":0},` + + `"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]},` + + `"private internet access":{"version":0,"timestamp":0},` + + `"privatevpn":{"version":0,"timestamp":0},` + + `"protonvpn":{"version":0,"timestamp":0},` + + `"purevpn":{"version":0,"timestamp":0},` + + `"surfshark":{"version":0,"timestamp":0},` + + `"torguard":{"version":0,"timestamp":0},` + + `"vpn unlimited":{"version":0,"timestamp":0},` + + `"vyprvpn":{"version":0,"timestamp":0},` + + `"wevpn":{"version":0,"timestamp":0},` + + `"windscribe":{"version":0,"timestamp":0}}`, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + // Populate all providers in all servers + for _, provider := range providers.All() { + _, has := testCase.allServers.ProviderToServers[provider] + if !has { + testCase.allServers.ProviderToServers[provider] = Servers{} + } + } + + data, err := testCase.allServers.MarshalJSON() + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.dataString, string(data)) + + data, err = json.Marshal(testCase.allServers) + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.dataString, string(data)) + + buffer := bytes.NewBuffer(nil) + encoder := json.NewEncoder(buffer) + // encoder.SetIndent("", " ") + err = encoder.Encode(testCase.allServers) + require.NoError(t, err) + assert.Equal(t, testCase.dataString+"\n", buffer.String()) + }) + } +} + +func Test_AllServers_UnmarshalJSON(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + dataString string + allServers AllServers + errWrapped error + errMessage string + }{ + "empty": { + dataString: "{}", + allServers: AllServers{}, + }, + "two known providers": { + dataString: `{"version":1,` + + `"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` + + `"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]}}`, + allServers: AllServers{ + Version: 1, + ProviderToServers: map[string]Servers{ + providers.Cyberghost: { + Version: 1, + Timestamp: 1000, + Servers: []Server{ + {Country: "A"}, + {Country: "B"}, + }, + }, + providers.Privado: { + Version: 2, + Timestamp: 2000, + Servers: []Server{ + {City: "C"}, + {City: "D"}, + }, + }, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + data := []byte(testCase.dataString) + var allServers AllServers + + err := json.Unmarshal(data, &allServers) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.allServers, allServers) + }) + } +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index e0fc93d8..2337239e 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -51,50 +51,51 @@ type PortForwarder interface { } func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider { + serversSlice := allServers.ServersSlice(provider) randSource := rand.NewSource(timeNow().UnixNano()) switch provider { case providers.Custom: return custom.New() case providers.Cyberghost: - return cyberghost.New(allServers.Cyberghost.Servers, randSource) + return cyberghost.New(serversSlice, randSource) case providers.Expressvpn: - return expressvpn.New(allServers.Expressvpn.Servers, randSource) + return expressvpn.New(serversSlice, randSource) case providers.Fastestvpn: - return fastestvpn.New(allServers.Fastestvpn.Servers, randSource) + return fastestvpn.New(serversSlice, randSource) case providers.HideMyAss: - return hidemyass.New(allServers.HideMyAss.Servers, randSource) + return hidemyass.New(serversSlice, randSource) case providers.Ipvanish: - return ipvanish.New(allServers.Ipvanish.Servers, randSource) + return ipvanish.New(serversSlice, randSource) case providers.Ivpn: - return ivpn.New(allServers.Ivpn.Servers, randSource) + return ivpn.New(serversSlice, randSource) case providers.Mullvad: - return mullvad.New(allServers.Mullvad.Servers, randSource) + return mullvad.New(serversSlice, randSource) case providers.Nordvpn: - return nordvpn.New(allServers.Nordvpn.Servers, randSource) + return nordvpn.New(serversSlice, randSource) case providers.Perfectprivacy: - return perfectprivacy.New(allServers.Perfectprivacy.Servers, randSource) + return perfectprivacy.New(serversSlice, randSource) case providers.Privado: - return privado.New(allServers.Privado.Servers, randSource) + return privado.New(serversSlice, randSource) case providers.PrivateInternetAccess: - return privateinternetaccess.New(allServers.Pia.Servers, randSource, timeNow) + return privateinternetaccess.New(serversSlice, randSource, timeNow) case providers.Privatevpn: - return privatevpn.New(allServers.Privatevpn.Servers, randSource) + return privatevpn.New(serversSlice, randSource) case providers.Protonvpn: - return protonvpn.New(allServers.Protonvpn.Servers, randSource) + return protonvpn.New(serversSlice, randSource) case providers.Purevpn: - return purevpn.New(allServers.Purevpn.Servers, randSource) + return purevpn.New(serversSlice, randSource) case providers.Surfshark: - return surfshark.New(allServers.Surfshark.Servers, randSource) + return surfshark.New(serversSlice, randSource) case providers.Torguard: - return torguard.New(allServers.Torguard.Servers, randSource) + return torguard.New(serversSlice, randSource) case providers.VPNUnlimited: - return vpnunlimited.New(allServers.VPNUnlimited.Servers, randSource) + return vpnunlimited.New(serversSlice, randSource) case providers.Vyprvpn: - return vyprvpn.New(allServers.Vyprvpn.Servers, randSource) + return vyprvpn.New(serversSlice, randSource) case providers.Wevpn: - return wevpn.New(allServers.Wevpn.Servers, randSource) + return wevpn.New(serversSlice, randSource) case providers.Windscribe: - return windscribe.New(allServers.Windscribe.Servers, randSource) + return windscribe.New(serversSlice, randSource) default: panic("provider " + provider + " is unknown") // should never occur } diff --git a/internal/storage/flush.go b/internal/storage/flush.go index 9ea29f7c..ac730e51 100644 --- a/internal/storage/flush.go +++ b/internal/storage/flush.go @@ -11,14 +11,14 @@ import ( var _ Flusher = (*Storage)(nil) type Flusher interface { - FlushToFile(allServers models.AllServers) error + FlushToFile(allServers *models.AllServers) error } -func (s *Storage) FlushToFile(allServers models.AllServers) error { +func (s *Storage) FlushToFile(allServers *models.AllServers) error { return flushToFile(s.filepath, allServers) } -func flushToFile(path string, servers models.AllServers) error { +func flushToFile(path string, servers *models.AllServers) error { dirPath := filepath.Dir(path) if err := os.MkdirAll(dirPath, 0644); err != nil { return err diff --git a/internal/storage/hardcoded_test.go b/internal/storage/hardcoded_test.go index 53f886de..7d8db6a0 100644 --- a/internal/storage/hardcoded_test.go +++ b/internal/storage/hardcoded_test.go @@ -3,6 +3,8 @@ package storage import ( "testing" + "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -12,5 +14,13 @@ func Test_parseHardcodedServers(t *testing.T) { servers, err := parseHardcodedServers() require.NoError(t, err) - require.NotEmpty(t, len(servers.Cyberghost.Servers)) + + // all providers minus custom + allProviders := providers.All() + require.Equal(t, len(allProviders), len(servers.ProviderToServers)) + for _, provider := range allProviders { + servers, ok := servers.ProviderToServers[provider] + assert.Truef(t, ok, "for provider %s", provider) + assert.NotEmptyf(t, servers, "for provider %s", provider) + } } diff --git a/internal/storage/merge.go b/internal/storage/merge.go index 914068d5..ec79b636 100644 --- a/internal/storage/merge.go +++ b/internal/storage/merge.go @@ -28,29 +28,20 @@ func (s *Storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int6 } func (s *Storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers { - return models.AllServers{ - Version: hardcoded.Version, - Cyberghost: s.mergeProviderServers(providers.Cyberghost, hardcoded.Cyberghost, persisted.Cyberghost), - Expressvpn: s.mergeProviderServers(providers.Expressvpn, hardcoded.Expressvpn, persisted.Expressvpn), - Fastestvpn: s.mergeProviderServers(providers.Fastestvpn, hardcoded.Fastestvpn, persisted.Fastestvpn), - HideMyAss: s.mergeProviderServers(providers.HideMyAss, hardcoded.HideMyAss, persisted.HideMyAss), - Ipvanish: s.mergeProviderServers(providers.Ipvanish, hardcoded.Ipvanish, persisted.Ipvanish), - Ivpn: s.mergeProviderServers(providers.Ivpn, hardcoded.Ivpn, persisted.Ivpn), - Mullvad: s.mergeProviderServers(providers.Mullvad, hardcoded.Mullvad, persisted.Mullvad), - Nordvpn: s.mergeProviderServers(providers.Nordvpn, hardcoded.Nordvpn, persisted.Nordvpn), - Perfectprivacy: s.mergeProviderServers(providers.Perfectprivacy, hardcoded.Perfectprivacy, persisted.Perfectprivacy), - Privado: s.mergeProviderServers(providers.Privado, hardcoded.Privado, persisted.Privado), - Pia: s.mergeProviderServers(providers.PrivateInternetAccess, hardcoded.Pia, persisted.Pia), - Privatevpn: s.mergeProviderServers(providers.Privatevpn, hardcoded.Privatevpn, persisted.Privatevpn), - Protonvpn: s.mergeProviderServers(providers.Protonvpn, hardcoded.Protonvpn, persisted.Protonvpn), - Purevpn: s.mergeProviderServers(providers.Purevpn, hardcoded.Purevpn, persisted.Purevpn), - Surfshark: s.mergeProviderServers(providers.Surfshark, hardcoded.Surfshark, persisted.Surfshark), - Torguard: s.mergeProviderServers(providers.Torguard, hardcoded.Torguard, persisted.Torguard), - VPNUnlimited: s.mergeProviderServers(providers.VPNUnlimited, hardcoded.VPNUnlimited, persisted.VPNUnlimited), - Vyprvpn: s.mergeProviderServers(providers.Vyprvpn, hardcoded.Vyprvpn, persisted.Vyprvpn), - Wevpn: s.mergeProviderServers(providers.Wevpn, hardcoded.Wevpn, persisted.Wevpn), - Windscribe: s.mergeProviderServers(providers.Windscribe, hardcoded.Windscribe, persisted.Windscribe), + allProviders := providers.All() + merged := models.AllServers{ + Version: hardcoded.Version, + ProviderToServers: make(map[string]models.Servers, len(allProviders)), } + + for _, provider := range allProviders { + hardcodedServers := hardcoded.ProviderToServers[provider] + persistedServers := persisted.ProviderToServers[provider] + merged.ProviderToServers[provider] = s.mergeProviderServers(provider, + hardcodedServers, persistedServers) + } + + return merged } func (s *Storage) mergeProviderServers(provider string, diff --git a/internal/storage/read.go b/internal/storage/read.go index 6b4b80b9..1bc8122c 100644 --- a/internal/storage/read.go +++ b/internal/storage/read.go @@ -38,172 +38,39 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) ( func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) ( servers models.AllServers, err error) { - var versions allVersions - if err := json.Unmarshal(b, &versions); err != nil { - return servers, fmt.Errorf("cannot decode versions: %w", err) - } - - var rawMessages allJSONRawMessages + rawMessages := make(map[string]json.RawMessage) if err := json.Unmarshal(b, &rawMessages); err != nil { return servers, fmt.Errorf("cannot decode servers: %w", err) } - type element struct { - provider string - hardcoded models.Servers - serverVersion serverVersion - rawMessage json.RawMessage - target *models.Servers - } - elements := []element{ - { - provider: providers.Cyberghost, - hardcoded: hardcoded.Cyberghost, - serverVersion: versions.Cyberghost, - rawMessage: rawMessages.Cyberghost, - target: &servers.Cyberghost, - }, - { - provider: providers.Expressvpn, - hardcoded: hardcoded.Expressvpn, - serverVersion: versions.Expressvpn, - rawMessage: rawMessages.Expressvpn, - target: &servers.Expressvpn, - }, - { - provider: providers.Fastestvpn, - hardcoded: hardcoded.Fastestvpn, - serverVersion: versions.Fastestvpn, - rawMessage: rawMessages.Fastestvpn, - target: &servers.Fastestvpn, - }, - { - provider: providers.HideMyAss, - hardcoded: hardcoded.HideMyAss, - serverVersion: versions.HideMyAss, - rawMessage: rawMessages.HideMyAss, - target: &servers.HideMyAss, - }, - { - provider: providers.Ipvanish, - hardcoded: hardcoded.Ipvanish, - serverVersion: versions.Ipvanish, - rawMessage: rawMessages.Ipvanish, - target: &servers.Ipvanish, - }, - { - provider: providers.Ivpn, - hardcoded: hardcoded.Ivpn, - serverVersion: versions.Ivpn, - rawMessage: rawMessages.Ivpn, - target: &servers.Ivpn, - }, - { - provider: providers.Mullvad, - hardcoded: hardcoded.Mullvad, - serverVersion: versions.Mullvad, - rawMessage: rawMessages.Mullvad, - target: &servers.Mullvad, - }, - { - provider: providers.Nordvpn, - hardcoded: hardcoded.Nordvpn, - serverVersion: versions.Nordvpn, - rawMessage: rawMessages.Nordvpn, - target: &servers.Nordvpn, - }, - { - provider: providers.Perfectprivacy, - hardcoded: hardcoded.Perfectprivacy, - serverVersion: versions.Perfectprivacy, - rawMessage: rawMessages.Perfectprivacy, - target: &servers.Perfectprivacy, - }, - { - provider: providers.Privado, - hardcoded: hardcoded.Privado, - serverVersion: versions.Privado, - rawMessage: rawMessages.Privado, - target: &servers.Privado, - }, - { - provider: providers.PrivateInternetAccess, - hardcoded: hardcoded.Pia, - serverVersion: versions.Pia, - rawMessage: rawMessages.Pia, - target: &servers.Pia, - }, - { - provider: providers.Privatevpn, - hardcoded: hardcoded.Privatevpn, - serverVersion: versions.Privatevpn, - rawMessage: rawMessages.Privatevpn, - target: &servers.Privatevpn, - }, - { - provider: providers.Protonvpn, - hardcoded: hardcoded.Protonvpn, - serverVersion: versions.Protonvpn, - rawMessage: rawMessages.Protonvpn, - target: &servers.Protonvpn, - }, - { - provider: providers.Purevpn, - hardcoded: hardcoded.Purevpn, - serverVersion: versions.Purevpn, - rawMessage: rawMessages.Purevpn, - target: &servers.Purevpn, - }, - { - provider: providers.Surfshark, - hardcoded: hardcoded.Surfshark, - serverVersion: versions.Surfshark, - rawMessage: rawMessages.Surfshark, - target: &servers.Surfshark, - }, - { - provider: providers.Torguard, - hardcoded: hardcoded.Torguard, - serverVersion: versions.Torguard, - rawMessage: rawMessages.Torguard, - target: &servers.Torguard, - }, - { - provider: providers.VPNUnlimited, - hardcoded: hardcoded.VPNUnlimited, - serverVersion: versions.VPNUnlimited, - rawMessage: rawMessages.VPNUnlimited, - target: &servers.VPNUnlimited, - }, - { - provider: providers.Vyprvpn, - hardcoded: hardcoded.Vyprvpn, - serverVersion: versions.Vyprvpn, - rawMessage: rawMessages.Vyprvpn, - target: &servers.Vyprvpn, - }, - { - provider: providers.Wevpn, - hardcoded: hardcoded.Wevpn, - serverVersion: versions.Wevpn, - rawMessage: rawMessages.Wevpn, - target: &servers.Wevpn, - }, - { - provider: providers.Windscribe, - hardcoded: hardcoded.Windscribe, - serverVersion: versions.Windscribe, - rawMessage: rawMessages.Windscribe, - target: &servers.Windscribe, - }, - } + // Note schema version is at map key "version" as number - for _, element := range elements { - *element.target, err = s.readServers(element.provider, - element.hardcoded, element.serverVersion, element.rawMessage) - if err != nil { - return servers, err + allProviders := providers.All() + servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)) + for _, provider := range allProviders { + hardcoded, ok := hardcoded.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider)) } + + rawMessage, ok := rawMessages[provider] + if !ok { + // If the provider is not found in the data bytes, just don't set it in + // the providers map. That way the hardcoded servers will override them. + // This is user provided and could come from different sources in the + // future (e.g. a file or API request). + continue + } + + mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage) + if err != nil { + return models.AllServers{}, err + } else if !versionsMatch { + // mergedServers is the empty struct in this case, so don't set the key + // in the providerToServers map. + continue + } + servers.ProviderToServers[provider] = mergedServers } return servers, nil @@ -214,73 +81,20 @@ var ( ) func (s *Storage) readServers(provider string, hardcoded models.Servers, - serverVersion serverVersion, rawMessage json.RawMessage) ( - servers models.Servers, err error) { + rawMessage json.RawMessage) (servers models.Servers, versionsMatch bool, err error) { provider = strings.Title(provider) - if hardcoded.Version != serverVersion.Version { - s.logVersionDiff(provider, hardcoded.Version, serverVersion.Version) - return servers, nil - } - err = json.Unmarshal(rawMessage, &servers) + var persistedServers models.Servers + err = json.Unmarshal(rawMessage, &persistedServers) if err != nil { - return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err) + return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err) } - return servers, nil -} + versionsMatch = hardcoded.Version == persistedServers.Version + if !versionsMatch { + s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version) + return servers, versionsMatch, nil + } -// allVersions is a subset of models.AllServers structure used to track -// versions to avoid unmarshaling errors. -type allVersions struct { - Version uint16 `json:"version"` // used for migration of the top level scheme - Cyberghost serverVersion `json:"cyberghost"` - Expressvpn serverVersion `json:"expressvpn"` - Fastestvpn serverVersion `json:"fastestvpn"` - HideMyAss serverVersion `json:"hidemyass"` - Ipvanish serverVersion `json:"ipvanish"` - Ivpn serverVersion `json:"ivpn"` - Mullvad serverVersion `json:"mullvad"` - Nordvpn serverVersion `json:"nordvpn"` - Perfectprivacy serverVersion `json:"perfect privacy"` - Privado serverVersion `json:"privado"` - Pia serverVersion `json:"private internet access"` - Privatevpn serverVersion `json:"privatevpn"` - Protonvpn serverVersion `json:"protonvpn"` - Purevpn serverVersion `json:"purevpn"` - Surfshark serverVersion `json:"surfshark"` - Torguard serverVersion `json:"torguard"` - VPNUnlimited serverVersion `json:"vpn unlimited"` - Vyprvpn serverVersion `json:"vyprvpn"` - Wevpn serverVersion `json:"wevpn"` - Windscribe serverVersion `json:"windscribe"` -} - -type serverVersion struct { - Version uint16 `json:"version"` -} - -// allJSONRawMessages is to delay decoding of each provider servers. -type allJSONRawMessages struct { - Version uint16 `json:"version"` // used for migration of the top level scheme - Cyberghost json.RawMessage `json:"cyberghost"` - Expressvpn json.RawMessage `json:"expressvpn"` - Fastestvpn json.RawMessage `json:"fastestvpn"` - HideMyAss json.RawMessage `json:"hidemyass"` - Ipvanish json.RawMessage `json:"ipvanish"` - Ivpn json.RawMessage `json:"ivpn"` - Mullvad json.RawMessage `json:"mullvad"` - Nordvpn json.RawMessage `json:"nordvpn"` - Perfectprivacy json.RawMessage `json:"perfect privacy"` - Privado json.RawMessage `json:"privado"` - Pia json.RawMessage `json:"private internet access"` - Privatevpn json.RawMessage `json:"privatevpn"` - Protonvpn json.RawMessage `json:"protonvpn"` - Purevpn json.RawMessage `json:"purevpn"` - Surfshark json.RawMessage `json:"surfshark"` - Torguard json.RawMessage `json:"torguard"` - VPNUnlimited json.RawMessage `json:"vpn unlimited"` - Vyprvpn json.RawMessage `json:"vyprvpn"` - Wevpn json.RawMessage `json:"wevpn"` - Windscribe json.RawMessage `json:"windscribe"` + return persistedServers, versionsMatch, nil } diff --git a/internal/storage/read_test.go b/internal/storage/read_test.go index ff3dfe6a..7f5d77ad 100644 --- a/internal/storage/read_test.go +++ b/internal/storage/read_test.go @@ -1,80 +1,89 @@ package storage import ( - "errors" + "fmt" "testing" - gomock "github.com/golang/mock/gomock" + "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func populateProviders(allProviderVersion uint16, allProviderTimestamp int64, + servers models.AllServers) models.AllServers { + allProviders := providers.All() + if servers.ProviderToServers == nil { + servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)-1) + } + for _, provider := range allProviders { + _, has := servers.ProviderToServers[provider] + if has { + continue + } + servers.ProviderToServers[provider] = models.Servers{ + Version: allProviderVersion, + Timestamp: allProviderTimestamp, + } + } + return servers +} + func Test_extractServersFromBytes(t *testing.T) { t.Parallel() testCases := map[string]struct { - b []byte - hardcoded models.AllServers - logged []string - persisted models.AllServers - err error + b []byte + hardcoded models.AllServers + logged []string + persisted models.AllServers + errMessage string }{ - "no data": { - err: errors.New("cannot decode versions: unexpected end of JSON input"), + "bad JSON": { + b: []byte("garbage"), + errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value", }, - "empty JSON": { - b: []byte("{}"), - err: errors.New("cannot decode servers for provider: Cyberghost: unexpected end of JSON input"), + "bad provider JSON": { + b: []byte(`{"cyberghost": "garbage"}`), + hardcoded: populateProviders(1, 0, models.AllServers{}), + errMessage: "cannot decode servers for provider: Cyberghost: " + + "json: cannot unmarshal string into Go value of type models.Servers", }, - "different versions": { - b: []byte(`{}`), - hardcoded: models.AllServers{ - Cyberghost: models.Servers{Version: 1}, - Expressvpn: models.Servers{Version: 1}, - Fastestvpn: models.Servers{Version: 1}, - HideMyAss: models.Servers{Version: 1}, - Ipvanish: models.Servers{Version: 1}, - Ivpn: models.Servers{Version: 1}, - Mullvad: models.Servers{Version: 1}, - Nordvpn: models.Servers{Version: 1}, - Perfectprivacy: models.Servers{Version: 1}, - Privado: models.Servers{Version: 1}, - Pia: models.Servers{Version: 1}, - Privatevpn: models.Servers{Version: 1}, - Protonvpn: models.Servers{Version: 1}, - Purevpn: models.Servers{Version: 1}, - Surfshark: models.Servers{Version: 1}, - Torguard: models.Servers{Version: 1}, - VPNUnlimited: models.Servers{Version: 1}, - Vyprvpn: models.Servers{Version: 1}, - Wevpn: models.Servers{Version: 1}, - Windscribe: models.Servers{Version: 1}, - }, - logged: []string{ - "Cyberghost servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Expressvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Fastestvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Hidemyass servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Ipvanish servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Ivpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Mullvad servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Nordvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Perfect Privacy servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Privado servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Private Internet Access servers from file discarded because they have version 0 and hardcoded servers have version 1", //nolint:lll - "Privatevpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Protonvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Purevpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Surfshark servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Torguard servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Vpn Unlimited servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Vyprvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Wevpn servers from file discarded because they have version 0 and hardcoded servers have version 1", - "Windscribe servers from file discarded because they have version 0 and hardcoded servers have version 1", + "absent provider keys": { + b: []byte(`{}`), + hardcoded: populateProviders(1, 0, models.AllServers{}), + persisted: models.AllServers{ + ProviderToServers: map[string]models.Servers{}, }, }, "same versions": { + b: []byte(`{ + "cyberghost": {"version": 1, "timestamp": 1}, + "expressvpn": {"version": 1, "timestamp": 1}, + "fastestvpn": {"version": 1, "timestamp": 1}, + "hidemyass": {"version": 1, "timestamp": 1}, + "ipvanish": {"version": 1, "timestamp": 1}, + "ivpn": {"version": 1, "timestamp": 1}, + "mullvad": {"version": 1, "timestamp": 1}, + "nordvpn": {"version": 1, "timestamp": 1}, + "perfect privacy": {"version": 1, "timestamp": 1}, + "privado": {"version": 1, "timestamp": 1}, + "private internet access": {"version": 1, "timestamp": 1}, + "privatevpn": {"version": 1, "timestamp": 1}, + "protonvpn": {"version": 1, "timestamp": 1}, + "purevpn": {"version": 1, "timestamp": 1}, + "surfshark": {"version": 1, "timestamp": 1}, + "torguard": {"version": 1, "timestamp": 1}, + "vpn unlimited": {"version": 1, "timestamp": 1}, + "vyprvpn": {"version": 1, "timestamp": 1}, + "wevpn": {"version": 1, "timestamp": 1}, + "windscribe": {"version": 1, "timestamp": 1} + }`), + hardcoded: populateProviders(1, 0, models.AllServers{}), + persisted: populateProviders(1, 1, models.AllServers{}), + }, + "different versions": { b: []byte(`{ "cyberghost": {"version": 1, "timestamp": 1}, "expressvpn": {"version": 1, "timestamp": 1}, @@ -97,49 +106,31 @@ func Test_extractServersFromBytes(t *testing.T) { "wevpn": {"version": 1, "timestamp": 1}, "windscribe": {"version": 1, "timestamp": 1} }`), - hardcoded: models.AllServers{ - Cyberghost: models.Servers{Version: 1}, - Expressvpn: models.Servers{Version: 1}, - Fastestvpn: models.Servers{Version: 1}, - HideMyAss: models.Servers{Version: 1}, - Ipvanish: models.Servers{Version: 1}, - Ivpn: models.Servers{Version: 1}, - Mullvad: models.Servers{Version: 1}, - Nordvpn: models.Servers{Version: 1}, - Perfectprivacy: models.Servers{Version: 1}, - Privado: models.Servers{Version: 1}, - Pia: models.Servers{Version: 1}, - Privatevpn: models.Servers{Version: 1}, - Protonvpn: models.Servers{Version: 1}, - Purevpn: models.Servers{Version: 1}, - Surfshark: models.Servers{Version: 1}, - Torguard: models.Servers{Version: 1}, - VPNUnlimited: models.Servers{Version: 1}, - Vyprvpn: models.Servers{Version: 1}, - Wevpn: models.Servers{Version: 1}, - Windscribe: models.Servers{Version: 1}, + hardcoded: populateProviders(2, 0, models.AllServers{}), + logged: []string{ + "Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Fastestvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Hidemyass servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Ipvanish servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Ivpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Mullvad servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Nordvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Perfect Privacy servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Privado servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Private Internet Access servers from file discarded because they have version 1 and hardcoded servers have version 2", //nolint:lll + "Privatevpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Protonvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Purevpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Surfshark servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Torguard servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Vpn Unlimited servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Vyprvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Wevpn servers from file discarded because they have version 1 and hardcoded servers have version 2", + "Windscribe servers from file discarded because they have version 1 and hardcoded servers have version 2", }, persisted: models.AllServers{ - Cyberghost: models.Servers{Version: 1, Timestamp: 1}, - Expressvpn: models.Servers{Version: 1, Timestamp: 1}, - Fastestvpn: models.Servers{Version: 1, Timestamp: 1}, - HideMyAss: models.Servers{Version: 1, Timestamp: 1}, - Ipvanish: models.Servers{Version: 1, Timestamp: 1}, - Ivpn: models.Servers{Version: 1, Timestamp: 1}, - Mullvad: models.Servers{Version: 1, Timestamp: 1}, - Nordvpn: models.Servers{Version: 1, Timestamp: 1}, - Perfectprivacy: models.Servers{Version: 1, Timestamp: 1}, - Privado: models.Servers{Version: 1, Timestamp: 1}, - Pia: models.Servers{Version: 1, Timestamp: 1}, - Privatevpn: models.Servers{Version: 1, Timestamp: 1}, - Protonvpn: models.Servers{Version: 1, Timestamp: 1}, - Purevpn: models.Servers{Version: 1, Timestamp: 1}, - Surfshark: models.Servers{Version: 1, Timestamp: 1}, - Torguard: models.Servers{Version: 1, Timestamp: 1}, - VPNUnlimited: models.Servers{Version: 1, Timestamp: 1}, - Vyprvpn: models.Servers{Version: 1, Timestamp: 1}, - Wevpn: models.Servers{Version: 1, Timestamp: 1}, - Windscribe: models.Servers{Version: 1, Timestamp: 1}, + ProviderToServers: map[string]models.Servers{}, }, }, } @@ -151,8 +142,13 @@ func Test_extractServersFromBytes(t *testing.T) { ctrl := gomock.NewController(t) logger := NewMockInfoErrorer(ctrl) + var previousLogCall *gomock.Call for _, logged := range testCase.logged { - logger.EXPECT().Info(logged) + call := logger.EXPECT().Info(logged) + if previousLogCall != nil { + call.After(previousLogCall) + } + previousLogCall = call } s := &Storage{ @@ -161,9 +157,8 @@ func Test_extractServersFromBytes(t *testing.T) { servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded) - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) + if testCase.errMessage != "" { + assert.EqualError(t, err, testCase.errMessage) } else { assert.NoError(t, err) } @@ -171,4 +166,25 @@ func Test_extractServersFromBytes(t *testing.T) { assert.Equal(t, testCase.persisted, servers) }) } + + t.Run("hardcoded panic", func(t *testing.T) { + t.Parallel() + + s := &Storage{} + + allProviders := providers.All() + require.GreaterOrEqual(t, len(allProviders), 2) + + b := []byte(`{}`) + hardcoded := models.AllServers{ + ProviderToServers: map[string]models.Servers{ + allProviders[0]: {}, + // Missing provider allProviders[1] + }, + } + expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1]) + assert.PanicsWithValue(t, expectedPanicValue, func() { + _, _ = s.extractServersFromBytes(b, hardcoded) + }) + }) } diff --git a/internal/storage/sync.go b/internal/storage/sync.go index e6ca57bb..881bcdce 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -7,27 +7,11 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -func countServers(allServers models.AllServers) int { - return len(allServers.Cyberghost.Servers) + - len(allServers.Expressvpn.Servers) + - len(allServers.Fastestvpn.Servers) + - len(allServers.HideMyAss.Servers) + - len(allServers.Ipvanish.Servers) + - len(allServers.Ivpn.Servers) + - len(allServers.Mullvad.Servers) + - len(allServers.Nordvpn.Servers) + - len(allServers.Perfectprivacy.Servers) + - len(allServers.Privado.Servers) + - len(allServers.Pia.Servers) + - len(allServers.Privatevpn.Servers) + - len(allServers.Protonvpn.Servers) + - len(allServers.Purevpn.Servers) + - len(allServers.Surfshark.Servers) + - len(allServers.Torguard.Servers) + - len(allServers.VPNUnlimited.Servers) + - len(allServers.Vyprvpn.Servers) + - len(allServers.Wevpn.Servers) + - len(allServers.Windscribe.Servers) +func countServers(allServers models.AllServers) (count int) { + for _, servers := range allServers.ProviderToServers { + count += len(servers.Servers) + } + return count } func (s *Storage) SyncServers() (err error) { @@ -57,7 +41,7 @@ func (s *Storage) SyncServers() (err error) { return nil } - if err := flushToFile(s.filepath, s.mergedServers); err != nil { + if err := flushToFile(s.filepath, &s.mergedServers); err != nil { return fmt.Errorf("cannot write servers to file: %w", err) } return nil diff --git a/internal/updater/loop.go b/internal/updater/loop.go index 2113231d..93f13ea3 100644 --- a/internal/updater/loop.go +++ b/internal/updater/loop.go @@ -139,7 +139,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { l.stopped <- struct{}{} case servers := <-serversCh: l.setAllServers(servers) - if err := l.flusher.FlushToFile(servers); err != nil { + if err := l.flusher.FlushToFile(&servers); err != nil { l.logger.Error(err.Error()) } runWg.Wait() diff --git a/internal/updater/providers.go b/internal/updater/providers.go index dc050482..498cb906 100644 --- a/internal/updater/providers.go +++ b/internal/updater/providers.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "reflect" + "time" "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/updater/providers/cyberghost" "github.com/qdm12/gluetun/internal/updater/providers/expressvpn" "github.com/qdm12/gluetun/internal/updater/providers/fastestvpn" @@ -28,419 +30,96 @@ import ( "github.com/qdm12/gluetun/internal/updater/providers/windscribe" ) -func (u *updater) updateCyberghost(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Cyberghost.Servers)) - servers, err := cyberghost.GetServers(ctx, u.presolver, minServers) +func (u *updater) updateProvider(ctx context.Context, provider string) ( + warnings []string, err error) { + existingServers := u.getProviderServers(provider) + minServers := getMinServers(existingServers) + servers, warnings, err := u.getServers(ctx, provider, minServers) if err != nil { - return err + return warnings, err } - if reflect.DeepEqual(u.servers.Cyberghost.Servers, servers) { - return nil + if reflect.DeepEqual(existingServers, servers) { + return warnings, nil } - u.servers.Cyberghost.Timestamp = u.timeNow().Unix() - u.servers.Cyberghost.Servers = servers - return nil + u.patchProvider(provider, servers) + return warnings, nil } -func (u *updater) updateExpressvpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Expressvpn.Servers)) - servers, warnings, err := expressvpn.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("ExpressVPN: " + warning) - } +func (u *updater) getServers(ctx context.Context, provider string, + minServers int) (servers []models.Server, warnings []string, err error) { + switch provider { + case providers.Custom: + panic("cannot update custom provider") + case providers.Cyberghost: + servers, err = cyberghost.GetServers(ctx, u.presolver, minServers) + return servers, nil, err + case providers.Expressvpn: + return expressvpn.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.Fastestvpn: + return fastestvpn.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.HideMyAss: + return hidemyass.GetServers(ctx, u.client, u.presolver, minServers) + case providers.Ipvanish: + return ipvanish.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.Ivpn: + return ivpn.GetServers(ctx, u.client, u.presolver, minServers) + case providers.Mullvad: + servers, err = mullvad.GetServers(ctx, u.client, minServers) + return servers, nil, err + case providers.Nordvpn: + return nordvpn.GetServers(ctx, u.client, minServers) + case providers.Perfectprivacy: + return perfectprivacy.GetServers(ctx, u.unzipper, minServers) + case providers.Privado: + return privado.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers) + case providers.PrivateInternetAccess: + servers, err = pia.GetServers(ctx, u.client, minServers) + return servers, nil, err + case providers.Privatevpn: + return privatevpn.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.Protonvpn: + return protonvpn.GetServers(ctx, u.client, minServers) + case providers.Purevpn: + return purevpn.GetServers(ctx, u.client, u.unzipper, u.presolver, minServers) + case providers.Surfshark: + return surfshark.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers) + case providers.Torguard: + return torguard.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.VPNUnlimited: + return vpnunlimited.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.Vyprvpn: + return vyprvpn.GetServers(ctx, u.unzipper, u.presolver, minServers) + case providers.Wevpn: + return wevpn.GetServers(ctx, u.presolver, minServers) + case providers.Windscribe: + servers, err = windscribe.GetServers(ctx, u.client, minServers) + return servers, nil, err + default: + panic("provider " + provider + " is unknown") } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Expressvpn.Servers, servers) { - return nil - } - - u.servers.Expressvpn.Timestamp = u.timeNow().Unix() - u.servers.Expressvpn.Servers = servers - return nil } -func (u *updater) updateFastestvpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Fastestvpn.Servers)) - servers, warnings, err := fastestvpn.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("FastestVPN: " + warning) - } +func (u *updater) getProviderServers(provider string) (servers []models.Server) { + providerServers, ok := u.servers.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("provider %s is unknown", provider)) } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Fastestvpn.Servers, servers) { - return nil - } - - u.servers.Fastestvpn.Timestamp = u.timeNow().Unix() - u.servers.Fastestvpn.Servers = servers - return nil + return providerServers.Servers } -func (u *updater) updateHideMyAss(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.HideMyAss.Servers)) - servers, warnings, err := hidemyass.GetServers( - ctx, u.client, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("HideMyAss: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.HideMyAss.Servers, servers) { - return nil - } - - u.servers.HideMyAss.Timestamp = u.timeNow().Unix() - u.servers.HideMyAss.Servers = servers - return nil -} - -func (u *updater) updateIpvanish(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Ipvanish.Servers)) - servers, warnings, err := ipvanish.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("Ipvanish: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Ipvanish.Servers, servers) { - return nil - } - - u.servers.Ipvanish.Timestamp = u.timeNow().Unix() - u.servers.Ipvanish.Servers = servers - return nil -} - -func (u *updater) updateIvpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Ivpn.Servers)) - servers, warnings, err := ivpn.GetServers( - ctx, u.client, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("Ivpn: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Ivpn.Servers, servers) { - return nil - } - - u.servers.Ivpn.Timestamp = u.timeNow().Unix() - u.servers.Ivpn.Servers = servers - return nil -} - -func (u *updater) updateMullvad(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Mullvad.Servers)) - servers, err := mullvad.GetServers(ctx, u.client, minServers) - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Mullvad.Servers, servers) { - return nil - } - - u.servers.Mullvad.Timestamp = u.timeNow().Unix() - u.servers.Mullvad.Servers = servers - return nil -} - -func (u *updater) updateNordvpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Nordvpn.Servers)) - servers, warnings, err := nordvpn.GetServers(ctx, u.client, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("NordVPN: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Nordvpn.Servers, servers) { - return nil - } - - u.servers.Nordvpn.Timestamp = u.timeNow().Unix() - u.servers.Nordvpn.Servers = servers - return nil -} - -func (u *updater) updatePerfectprivacy(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Perfectprivacy.Servers)) - servers, warnings, err := perfectprivacy.GetServers(ctx, u.unzipper, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn(providers.Perfectprivacy + ": " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Perfectprivacy.Servers, servers) { - return nil - } - - u.servers.Perfectprivacy.Timestamp = u.timeNow().Unix() - u.servers.Perfectprivacy.Servers = servers - return nil -} - -func (u *updater) updatePIA(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Pia.Servers)) - servers, err := pia.GetServers(ctx, u.client, minServers) - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Pia.Servers, servers) { - return nil - } - - u.servers.Pia.Timestamp = u.timeNow().Unix() - u.servers.Pia.Servers = servers - return nil -} - -func (u *updater) updatePrivado(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Privado.Servers)) - servers, warnings, err := privado.GetServers( - ctx, u.unzipper, u.client, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("Privado: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Privado.Servers, servers) { - return nil - } - - u.servers.Privado.Timestamp = u.timeNow().Unix() - u.servers.Privado.Servers = servers - return nil -} - -func (u *updater) updatePrivatevpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Privatevpn.Servers)) - servers, warnings, err := privatevpn.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("PrivateVPN: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Privatevpn.Servers, servers) { - return nil - } - - u.servers.Privatevpn.Timestamp = u.timeNow().Unix() - u.servers.Privatevpn.Servers = servers - return nil -} - -func (u *updater) updateProtonvpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Privatevpn.Servers)) - servers, warnings, err := protonvpn.GetServers(ctx, u.client, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("ProtonVPN: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Protonvpn.Servers, servers) { - return nil - } - - u.servers.Protonvpn.Timestamp = u.timeNow().Unix() - u.servers.Protonvpn.Servers = servers - return nil -} - -func (u *updater) updatePurevpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Purevpn.Servers)) - servers, warnings, err := purevpn.GetServers( - ctx, u.client, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("PureVPN: " + warning) - } - } - if err != nil { - return fmt.Errorf("cannot update Purevpn servers: %w", err) - } - - if reflect.DeepEqual(u.servers.Purevpn.Servers, servers) { - return nil - } - - u.servers.Purevpn.Timestamp = u.timeNow().Unix() - u.servers.Purevpn.Servers = servers - return nil -} - -func (u *updater) updateSurfshark(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Surfshark.Servers)) - servers, warnings, err := surfshark.GetServers( - ctx, u.unzipper, u.client, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("Surfshark: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Surfshark.Servers, servers) { - return nil - } - - u.servers.Surfshark.Timestamp = u.timeNow().Unix() - u.servers.Surfshark.Servers = servers - return nil -} - -func (u *updater) updateTorguard(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Torguard.Servers)) - servers, warnings, err := torguard.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("Torguard: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Torguard.Servers, servers) { - return nil - } - - u.servers.Torguard.Timestamp = u.timeNow().Unix() - u.servers.Torguard.Servers = servers - return nil -} - -func (u *updater) updateVPNUnlimited(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.VPNUnlimited.Servers)) - servers, warnings, err := vpnunlimited.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn(providers.VPNUnlimited + ": " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.VPNUnlimited.Servers, servers) { - return nil - } - - u.servers.VPNUnlimited.Timestamp = u.timeNow().Unix() - u.servers.VPNUnlimited.Servers = servers - return nil -} - -func (u *updater) updateVyprvpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Vyprvpn.Servers)) - servers, warnings, err := vyprvpn.GetServers( - ctx, u.unzipper, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("VyprVPN: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Vyprvpn.Servers, servers) { - return nil - } - - u.servers.Vyprvpn.Timestamp = u.timeNow().Unix() - u.servers.Vyprvpn.Servers = servers - return nil -} - -func (u *updater) updateWevpn(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Wevpn.Servers)) - servers, warnings, err := wevpn.GetServers(ctx, u.presolver, minServers) - if *u.options.CLI { - for _, warning := range warnings { - u.logger.Warn("WeVPN: " + warning) - } - } - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Wevpn.Servers, servers) { - return nil - } - - u.servers.Wevpn.Timestamp = u.timeNow().Unix() - u.servers.Wevpn.Servers = servers - return nil -} - -func (u *updater) updateWindscribe(ctx context.Context) (err error) { - minServers := getMinServers(len(u.servers.Windscribe.Servers)) - servers, err := windscribe.GetServers(ctx, u.client, minServers) - if err != nil { - return err - } - - if reflect.DeepEqual(u.servers.Windscribe.Servers, servers) { - return nil - } - - u.servers.Windscribe.Timestamp = u.timeNow().Unix() - u.servers.Windscribe.Servers = servers - return nil -} - -func getMinServers(existingServers int) (minServers int) { +func getMinServers(servers []models.Server) (minServers int) { const minRatio = 0.8 - return int(minRatio * float64(existingServers)) + return int(minRatio * float64(len(servers))) +} + +func (u *updater) patchProvider(provider string, servers []models.Server) { + providerServers, ok := u.servers.ProviderToServers[provider] + if !ok { + panic(fmt.Sprintf("provider %s is unknown", provider)) + } + providerServers.Timestamp = time.Now().Unix() + providerServers.Servers = servers + u.servers.ProviderToServers[provider] = providerServers } diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 30706492..2703b892 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -8,7 +8,6 @@ import ( "time" "github.com/qdm12/gluetun/internal/configuration/settings" - "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/unzip" @@ -47,16 +46,17 @@ func New(settings settings.Updater, httpClient *http.Client, } } -type updateFunc func(ctx context.Context) (err error) - func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) { for _, provider := range u.options.Providers { u.logger.Info("updating " + strings.Title(provider) + " servers...") - updateProvider := u.getUpdateFunction(provider) - // TODO support servers offering only TCP or only UDP // for NordVPN and PureVPN - err = updateProvider(ctx) + warnings, err := u.updateProvider(ctx, provider) + if *u.options.CLI { + for _, warning := range warnings { + u.logger.Warn(provider + ": " + warning) + } + } if err != nil { if ctxErr := ctx.Err(); ctxErr != nil { return allServers, ctxErr @@ -67,52 +67,3 @@ func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServe return u.servers, nil } - -func (u *updater) getUpdateFunction(provider string) (updateFunction updateFunc) { - switch provider { - case providers.Custom: - panic("cannot update custom provider") - case providers.Cyberghost: - return func(ctx context.Context) (err error) { return u.updateCyberghost(ctx) } - case providers.Expressvpn: - return func(ctx context.Context) (err error) { return u.updateExpressvpn(ctx) } - case providers.Fastestvpn: - return func(ctx context.Context) (err error) { return u.updateFastestvpn(ctx) } - case providers.HideMyAss: - return func(ctx context.Context) (err error) { return u.updateHideMyAss(ctx) } - case providers.Ipvanish: - return func(ctx context.Context) (err error) { return u.updateIpvanish(ctx) } - case providers.Ivpn: - return func(ctx context.Context) (err error) { return u.updateIvpn(ctx) } - case providers.Mullvad: - return func(ctx context.Context) (err error) { return u.updateMullvad(ctx) } - case providers.Nordvpn: - return func(ctx context.Context) (err error) { return u.updateNordvpn(ctx) } - case providers.Perfectprivacy: - return func(ctx context.Context) (err error) { return u.updatePerfectprivacy(ctx) } - case providers.Privado: - return func(ctx context.Context) (err error) { return u.updatePrivado(ctx) } - case providers.PrivateInternetAccess: - return func(ctx context.Context) (err error) { return u.updatePIA(ctx) } - case providers.Privatevpn: - return func(ctx context.Context) (err error) { return u.updatePrivatevpn(ctx) } - case providers.Protonvpn: - return func(ctx context.Context) (err error) { return u.updateProtonvpn(ctx) } - case providers.Purevpn: - return func(ctx context.Context) (err error) { return u.updatePurevpn(ctx) } - case providers.Surfshark: - return func(ctx context.Context) (err error) { return u.updateSurfshark(ctx) } - case providers.Torguard: - return func(ctx context.Context) (err error) { return u.updateTorguard(ctx) } - case providers.VPNUnlimited: - return func(ctx context.Context) (err error) { return u.updateVPNUnlimited(ctx) } - case providers.Vyprvpn: - return func(ctx context.Context) (err error) { return u.updateVyprvpn(ctx) } - case providers.Wevpn: - return func(ctx context.Context) (err error) { return u.updateWevpn(ctx) } - case providers.Windscribe: - return func(ctx context.Context) (err error) { return u.updateWindscribe(ctx) } - default: - panic("provider " + provider + " is unknown") - } -}