chore(all): provider to servers map in allServers

- Simplify formatting CLI
- Simplify updater code
- Simplify filter choices for config validation
- Simplify all servers deep copying
- Custom JSON marshaling methods for `AllServers`
- Simplify provider constructor switch
- Simplify storage merging
- Simplify storage reading and extraction
- Simplify updating code
This commit is contained in:
Quentin McGaw
2022-05-27 00:59:47 +00:00
parent 5ffe8555ba
commit bd0868d764
22 changed files with 854 additions and 1295 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
) )
@@ -18,8 +19,9 @@ type ServersFormatter interface {
} }
var ( var (
ErrFormatNotRecognized = errors.New("format is not recognized") ErrFormatNotRecognized = errors.New("format is not recognized")
ErrProviderUnspecified = errors.New("VPN provider to format was not specified") ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
) )
func addProviderFlag(flagSet *flag.FlagSet, func addProviderFlag(flagSet *flag.FlagSet,
@@ -31,21 +33,12 @@ func addProviderFlag(flagSet *flag.FlagSet,
flagSet.BoolVar(boolPtr, provider, false, "Format "+strings.Title(provider)+" servers") flagSet.BoolVar(boolPtr, provider, false, "Format "+strings.Title(provider)+" servers")
} }
func getFormatForProvider(providerToFormat map[string]*bool, provider string) (format bool) {
formatPtr, ok := providerToFormat[provider]
if !ok {
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
}
return *formatPtr
}
func (c *CLI) FormatServers(args []string) error { func (c *CLI) FormatServers(args []string) error {
var format, output string var format, output string
allProviders := providers.All() allProviders := providers.All()
providersToFormat := make(map[string]*bool, len(allProviders)) providersToFormat := make(map[string]*bool, len(allProviders))
for _, provider := range allProviders { for _, provider := range allProviders {
value := false providersToFormat[provider] = new(bool)
providersToFormat[provider] = &value
} }
flagSet := flag.NewFlagSet("markdown", flag.ExitOnError) flagSet := flag.NewFlagSet("markdown", flag.ExitOnError)
flagSet.StringVar(&format, "format", "markdown", "Format to use which can be: 'markdown'") flagSet.StringVar(&format, "format", "markdown", "Format to use which can be: 'markdown'")
@@ -61,6 +54,24 @@ func (c *CLI) FormatServers(args []string) error {
return fmt.Errorf("%w: %s", ErrFormatNotRecognized, format) return fmt.Errorf("%w: %s", ErrFormatNotRecognized, format)
} }
// Verify only one provider is set to be formatted.
var providers []string
for provider, formatPtr := range providersToFormat {
if *formatPtr {
providers = append(providers, provider)
}
}
switch len(providers) {
case 0:
return ErrProviderUnspecified
case 1:
default:
return fmt.Errorf("%w: %d specified: %s",
ErrMultipleProvidersToFormat, len(providers),
strings.Join(providers, ", "))
}
providerToFormat := providers[0]
logger := newNoopLogger() logger := newNoopLogger()
storage, err := storage.New(logger, constants.ServersData) storage, err := storage.New(logger, constants.ServersData)
if err != nil { if err != nil {
@@ -68,51 +79,7 @@ func (c *CLI) FormatServers(args []string) error {
} }
currentServers := storage.GetServers() currentServers := storage.GetServers()
var formatted string formatted := formatServers(currentServers, providerToFormat)
switch {
case getFormatForProvider(providersToFormat, providers.Cyberghost):
formatted = currentServers.Cyberghost.ToMarkdown(providers.Cyberghost)
case getFormatForProvider(providersToFormat, providers.Expressvpn):
formatted = currentServers.Expressvpn.ToMarkdown(providers.Expressvpn)
case getFormatForProvider(providersToFormat, providers.Fastestvpn):
formatted = currentServers.Fastestvpn.ToMarkdown(providers.Fastestvpn)
case getFormatForProvider(providersToFormat, providers.HideMyAss):
formatted = currentServers.HideMyAss.ToMarkdown(providers.HideMyAss)
case getFormatForProvider(providersToFormat, providers.Ipvanish):
formatted = currentServers.Ipvanish.ToMarkdown(providers.Ipvanish)
case getFormatForProvider(providersToFormat, providers.Ivpn):
formatted = currentServers.Ivpn.ToMarkdown(providers.Ivpn)
case getFormatForProvider(providersToFormat, providers.Mullvad):
formatted = currentServers.Mullvad.ToMarkdown(providers.Mullvad)
case getFormatForProvider(providersToFormat, providers.Nordvpn):
formatted = currentServers.Nordvpn.ToMarkdown(providers.Nordvpn)
case getFormatForProvider(providersToFormat, providers.Perfectprivacy):
formatted = currentServers.Perfectprivacy.ToMarkdown(providers.Perfectprivacy)
case getFormatForProvider(providersToFormat, providers.PrivateInternetAccess):
formatted = currentServers.Pia.ToMarkdown(providers.PrivateInternetAccess)
case getFormatForProvider(providersToFormat, providers.Privado):
formatted = currentServers.Privado.ToMarkdown(providers.Privado)
case getFormatForProvider(providersToFormat, providers.Privatevpn):
formatted = currentServers.Privatevpn.ToMarkdown(providers.Privatevpn)
case getFormatForProvider(providersToFormat, providers.Protonvpn):
formatted = currentServers.Protonvpn.ToMarkdown(providers.Protonvpn)
case getFormatForProvider(providersToFormat, providers.Purevpn):
formatted = currentServers.Purevpn.ToMarkdown(providers.Purevpn)
case getFormatForProvider(providersToFormat, providers.Surfshark):
formatted = currentServers.Surfshark.ToMarkdown(providers.Surfshark)
case getFormatForProvider(providersToFormat, providers.Torguard):
formatted = currentServers.Torguard.ToMarkdown(providers.Torguard)
case getFormatForProvider(providersToFormat, providers.VPNUnlimited):
formatted = currentServers.VPNUnlimited.ToMarkdown(providers.VPNUnlimited)
case getFormatForProvider(providersToFormat, providers.Vyprvpn):
formatted = currentServers.Vyprvpn.ToMarkdown(providers.Vyprvpn)
case getFormatForProvider(providersToFormat, providers.Wevpn):
formatted = currentServers.Wevpn.ToMarkdown(providers.Wevpn)
case getFormatForProvider(providersToFormat, providers.Windscribe):
formatted = currentServers.Windscribe.ToMarkdown(providers.Windscribe)
default:
return ErrProviderUnspecified
}
output = filepath.Clean(output) output = filepath.Clean(output)
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
@@ -133,3 +100,11 @@ func (c *CLI) FormatServers(args []string) error {
return nil return nil
} }
func formatServers(allServers models.AllServers, provider string) (formatted string) {
servers, ok := allServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
}
return servers.ToMarkdown(providers.Cyberghost)
}

View File

@@ -63,12 +63,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
} }
if updateAll { if updateAll {
for _, provider := range providers.All() { options.Providers = providers.All()
if provider == providers.Custom {
continue
}
options.Providers = append(options.Providers, provider)
}
} else { } else {
if csvProviders == "" { if csvProviders == "" {
return ErrNoProviderSpecified return ErrNoProviderSpecified
@@ -99,13 +94,13 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
} }
if endUserMode { if endUserMode {
if err := storage.FlushToFile(allServers); err != nil { if err := storage.FlushToFile(&allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err) return fmt.Errorf("cannot write updated information to file: %w", err)
} }
} }
if maintainerMode { if maintainerMode {
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil { if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err) return fmt.Errorf("cannot write updated information to file: %w", err)
} }
} }
@@ -114,7 +109,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
} }
func writeToEmbeddedJSON(repoServersPath string, func writeToEmbeddedJSON(repoServersPath string,
allServers models.AllServers) error { allServers *models.AllServers) error {
const perms = 0600 const perms = 0600
f, err := os.OpenFile(repoServersPath, f, err := os.OpenFile(repoServersPath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms) os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)

View File

@@ -27,7 +27,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
// Validate Name // Validate Name
var validNames []string var validNames []string
if vpnType == vpn.OpenVPN { if vpnType == vpn.OpenVPN {
validNames = providers.All() validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility validNames = append(validNames, "pia") // Retro-compatibility
} else { // Wireguard } else { // Wireguard
validNames = []string{ validNames = []string{

View File

@@ -140,118 +140,26 @@ func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection,
countryChoices, regionChoices, cityChoices, countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices []string, ispChoices, nameChoices, hostnameChoices []string,
err error) { err error) {
switch vpnServiceProvider { providerServers, ok := allServers.ProviderToServers[vpnServiceProvider]
case providers.Custom: if !ok {
case providers.Cyberghost: panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider))
servers := allServers.GetCyberghost() }
countryChoices = validation.ExtractCountries(servers) servers := providerServers.Servers
hostnameChoices = validation.ExtractHostnames(servers) countryChoices = validation.ExtractCountries(servers)
case providers.Expressvpn: regionChoices = validation.ExtractRegions(servers)
servers := allServers.GetExpressvpn() cityChoices = validation.ExtractCities(servers)
countryChoices = validation.ExtractCountries(servers) ispChoices = validation.ExtractISPs(servers)
cityChoices = validation.ExtractCities(servers) nameChoices = validation.ExtractServerNames(servers)
hostnameChoices = validation.ExtractHostnames(servers) hostnameChoices = validation.ExtractHostnames(servers)
case providers.Fastestvpn:
servers := allServers.GetFastestvpn() if vpnServiceProvider == providers.Surfshark {
countryChoices = validation.ExtractCountries(servers) // // Retro compatibility
hostnameChoices = validation.ExtractHostnames(servers)
case providers.HideMyAss:
servers := allServers.GetHideMyAss()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Ipvanish:
servers := allServers.GetIpvanish()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Ivpn:
servers := allServers.GetIvpn()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
ispChoices = validation.ExtractISPs(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Mullvad:
servers := allServers.GetMullvad()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
ispChoices = validation.ExtractISPs(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Nordvpn:
servers := allServers.GetNordvpn()
regionChoices = validation.ExtractRegions(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Perfectprivacy:
servers := allServers.GetPerfectprivacy()
cityChoices = validation.ExtractCities(servers)
case providers.Privado:
servers := allServers.GetPrivado()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.PrivateInternetAccess:
servers := allServers.GetPia()
regionChoices = validation.ExtractRegions(servers)
hostnameChoices = validation.ExtractHostnames(servers)
nameChoices = validation.ExtractServerNames(servers)
case providers.Privatevpn:
servers := allServers.GetPrivatevpn()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Protonvpn:
servers := allServers.GetProtonvpn()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
nameChoices = validation.ExtractServerNames(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Purevpn:
servers := allServers.GetPurevpn()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Surfshark:
servers := allServers.GetSurfshark()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
regionChoices = validation.ExtractRegions(servers)
// TODO v4 remove // TODO v4 remove
regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...) regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...)
if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil { if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err) return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
} }
// Retro compatibility
// TODO remove in v4
*ss = surfsharkRetroRegion(*ss) *ss = surfsharkRetroRegion(*ss)
case providers.Torguard:
servers := allServers.GetTorguard()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.VPNUnlimited:
servers := allServers.GetVPNUnlimited()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Vyprvpn:
servers := allServers.GetVyprvpn()
regionChoices = validation.ExtractRegions(servers)
case providers.Wevpn:
servers := allServers.GetWevpn()
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Windscribe:
servers := allServers.GetWindscribe()
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrVPNProviderNameNotValid, vpnServiceProvider)
} }
return countryChoices, regionChoices, cityChoices, return countryChoices, regionChoices, cityChoices,

View File

@@ -43,10 +43,6 @@ func (u Updater) Validate() (err error) {
for i, provider := range u.Providers { for i, provider := range u.Providers {
valid := false valid := false
for _, validProvider := range providers.All() { for _, validProvider := range providers.All() {
if validProvider == providers.Custom {
continue
}
if provider == validProvider { if provider == validProvider {
valid = true valid = true
break break

View File

@@ -19,6 +19,9 @@ func ExtractCountries(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers)) values = make([]string, 0, len(servers))
for _, server := range servers { for _, server := range servers {
value := server.Country value := server.Country
if value == "" {
continue
}
_, alreadySeen := seen[value] _, alreadySeen := seen[value]
if alreadySeen { if alreadySeen {
continue continue
@@ -35,6 +38,9 @@ func ExtractRegions(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers)) values = make([]string, 0, len(servers))
for _, server := range servers { for _, server := range servers {
value := server.Region value := server.Region
if value == "" {
continue
}
_, alreadySeen := seen[value] _, alreadySeen := seen[value]
if alreadySeen { if alreadySeen {
continue continue
@@ -51,6 +57,9 @@ func ExtractCities(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers)) values = make([]string, 0, len(servers))
for _, server := range servers { for _, server := range servers {
value := server.City value := server.City
if value == "" {
continue
}
_, alreadySeen := seen[value] _, alreadySeen := seen[value]
if alreadySeen { if alreadySeen {
continue continue
@@ -67,6 +76,9 @@ func ExtractISPs(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers)) values = make([]string, 0, len(servers))
for _, server := range servers { for _, server := range servers {
value := server.ISP value := server.ISP
if value == "" {
continue
}
_, alreadySeen := seen[value] _, alreadySeen := seen[value]
if alreadySeen { if alreadySeen {
continue continue
@@ -83,6 +95,9 @@ func ExtractServerNames(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers)) values = make([]string, 0, len(servers))
for _, server := range servers { for _, server := range servers {
value := server.ServerName value := server.ServerName
if value == "" {
continue
}
_, alreadySeen := seen[value] _, alreadySeen := seen[value]
if alreadySeen { if alreadySeen {
continue continue
@@ -99,6 +114,9 @@ func ExtractHostnames(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers)) values = make([]string, 0, len(servers))
for _, server := range servers { for _, server := range servers {
value := server.Hostname value := server.Hostname
if value == "" {
continue
}
_, alreadySeen := seen[value] _, alreadySeen := seen[value]
if alreadySeen { if alreadySeen {
continue continue

View File

@@ -26,9 +26,9 @@ const (
Windscribe = "windscribe" Windscribe = "windscribe"
) )
// All returns all the providers except the custom provider.
func All() []string { func All() []string {
return []string{ return []string{
Custom,
Cyberghost, Cyberghost,
Expressvpn, Expressvpn,
Fastestvpn, Fastestvpn,
@@ -51,3 +51,11 @@ func All() []string {
Windscribe, Windscribe,
} }
} }
func AllWithCustom() []string {
allProviders := All()
allProvidersWithCustom := make([]string, len(allProviders)+1)
copy(allProvidersWithCustom, allProviders)
allProvidersWithCustom[len(allProvidersWithCustom)-1] = Custom
return allProvidersWithCustom
}

View File

@@ -0,0 +1,23 @@
package providers
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_All(t *testing.T) {
t.Parallel()
all := All()
assert.NotContains(t, all, Custom)
assert.NotEmpty(t, all)
}
func Test_AllWithCustom(t *testing.T) {
t.Parallel()
all := AllWithCustom()
assert.Contains(t, all, Custom)
assert.Len(t, all, len(All())+1)
}

View File

@@ -4,108 +4,17 @@ import (
"net" "net"
) )
func (a AllServers) GetCopy() (servers AllServers) { func (a AllServers) GetCopy() (allServersCopy AllServers) {
servers = a // copy versions and timestamps allServersCopy.Version = a.Version
servers.Cyberghost.Servers = a.GetCyberghost() allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
servers.Expressvpn.Servers = a.GetExpressvpn() for provider, servers := range a.ProviderToServers {
servers.Fastestvpn.Servers = a.GetFastestvpn() allServersCopy.ProviderToServers[provider] = Servers{
servers.HideMyAss.Servers = a.GetHideMyAss() Version: servers.Version,
servers.Ipvanish.Servers = a.GetIpvanish() Timestamp: servers.Timestamp,
servers.Ivpn.Servers = a.GetIvpn() Servers: copyServers(servers.Servers),
servers.Mullvad.Servers = a.GetMullvad() }
servers.Nordvpn.Servers = a.GetNordvpn() }
servers.Perfectprivacy.Servers = a.GetPerfectprivacy() return allServersCopy
servers.Privado.Servers = a.GetPrivado()
servers.Pia.Servers = a.GetPia()
servers.Privatevpn.Servers = a.GetPrivatevpn()
servers.Protonvpn.Servers = a.GetProtonvpn()
servers.Purevpn.Servers = a.GetPurevpn()
servers.Surfshark.Servers = a.GetSurfshark()
servers.Torguard.Servers = a.GetTorguard()
servers.VPNUnlimited.Servers = a.GetVPNUnlimited()
servers.Vyprvpn.Servers = a.GetVyprvpn()
servers.Windscribe.Servers = a.GetWindscribe()
return servers
}
func (a *AllServers) GetCyberghost() (servers []Server) {
return copyServers(a.Cyberghost.Servers)
}
func (a *AllServers) GetExpressvpn() (servers []Server) {
return copyServers(a.Expressvpn.Servers)
}
func (a *AllServers) GetFastestvpn() (servers []Server) {
return copyServers(a.Fastestvpn.Servers)
}
func (a *AllServers) GetHideMyAss() (servers []Server) {
return copyServers(a.HideMyAss.Servers)
}
func (a *AllServers) GetIpvanish() (servers []Server) {
return copyServers(a.Ipvanish.Servers)
}
func (a *AllServers) GetIvpn() (servers []Server) {
return copyServers(a.Ivpn.Servers)
}
func (a *AllServers) GetMullvad() (servers []Server) {
return copyServers(a.Mullvad.Servers)
}
func (a *AllServers) GetNordvpn() (servers []Server) {
return copyServers(a.Nordvpn.Servers)
}
func (a *AllServers) GetPerfectprivacy() (servers []Server) {
return copyServers(a.Perfectprivacy.Servers)
}
func (a *AllServers) GetPia() (servers []Server) {
return copyServers(a.Pia.Servers)
}
func (a *AllServers) GetPrivado() (servers []Server) {
return copyServers(a.Privado.Servers)
}
func (a *AllServers) GetPrivatevpn() (servers []Server) {
return copyServers(a.Privatevpn.Servers)
}
func (a *AllServers) GetProtonvpn() (servers []Server) {
return copyServers(a.Protonvpn.Servers)
}
func (a *AllServers) GetPurevpn() (servers []Server) {
return copyServers(a.Purevpn.Servers)
}
func (a *AllServers) GetSurfshark() (servers []Server) {
return copyServers(a.Surfshark.Servers)
}
func (a *AllServers) GetTorguard() (servers []Server) {
return copyServers(a.Torguard.Servers)
}
func (a *AllServers) GetVPNUnlimited() (servers []Server) {
return copyServers(a.VPNUnlimited.Servers)
}
func (a *AllServers) GetVyprvpn() (servers []Server) {
return copyServers(a.Vyprvpn.Servers)
}
func (a *AllServers) GetWevpn() (servers []Server) {
return copyServers(a.Wevpn.Servers)
}
func (a *AllServers) GetWindscribe() (servers []Server) {
return copyServers(a.Windscribe.Servers)
} }
func copyServers(servers []Server) (serversCopy []Server) { func copyServers(servers []Server) (serversCopy []Server) {

View File

@@ -4,108 +4,117 @@ import (
"net" "net"
"testing" "testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func Test_AllServers_GetCopy(t *testing.T) { func Test_AllServers_GetCopy(t *testing.T) {
allServers := AllServers{ allServers := AllServers{
Cyberghost: Servers{ Version: 1,
Version: 2, ProviderToServers: map[string]Servers{
Servers: []Server{{ providers.Cyberghost: {
IPs: []net.IP{{1, 2, 3, 4}}, Version: 2,
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Expressvpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Expressvpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Fastestvpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Fastestvpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
HideMyAss: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.HideMyAss: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Ipvanish: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Ipvanish: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Ivpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Ivpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Mullvad: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Mullvad: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Nordvpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Nordvpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Perfectprivacy: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Perfectprivacy: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Privado: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Privado: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Pia: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.PrivateInternetAccess: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Privatevpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Privatevpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Protonvpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Protonvpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Purevpn: Servers{ }},
Version: 1, },
Servers: []Server{{ providers.Purevpn: {
IPs: []net.IP{{1, 2, 3, 4}}, Version: 1,
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Surfshark: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Surfshark: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Torguard: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Torguard: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
VPNUnlimited: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.VPNUnlimited: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Vyprvpn: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Vyprvpn: {
}}, Servers: []Server{{
}, IPs: []net.IP{{1, 2, 3, 4}},
Windscribe: Servers{ }},
Servers: []Server{{ },
IPs: []net.IP{{1, 2, 3, 4}}, providers.Wevpn: {
}}, Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Windscribe: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
}, },
} }
@@ -114,32 +123,6 @@ func Test_AllServers_GetCopy(t *testing.T) {
assert.Equal(t, allServers, servers) assert.Equal(t, allServers, servers)
} }
func Test_AllServers_GetVyprvpn(t *testing.T) {
allServers := AllServers{
Vyprvpn: Servers{
Servers: []Server{
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
},
},
}
servers := allServers.GetVyprvpn()
expectedServers := []Server{
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
}
assert.Equal(t, expectedServers, servers)
allServers.Vyprvpn.Servers[0].IPs[0][0] = 9
assert.NotEqual(t, 9, servers[0].IPs[0][0])
allServers.Vyprvpn.Servers[0].IPs[0][0] = 1
servers[0].IPs[0][0] = 9
assert.NotEqual(t, 9, allServers.Vyprvpn.Servers[0].IPs[0][0])
}
func Test_copyIPs(t *testing.T) { func Test_copyIPs(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -1,54 +1,163 @@
package models package models
import (
"bytes"
"encoding/json"
"fmt"
"math"
"reflect"
"github.com/qdm12/gluetun/internal/constants/providers"
)
type AllServers struct { type AllServers struct {
Version uint16 `json:"version"` // used for migration of the top level scheme Version uint16 // used for migration of the top level scheme
Cyberghost Servers `json:"cyberghost"` ProviderToServers map[string]Servers
Expressvpn Servers `json:"expressvpn"`
Fastestvpn Servers `json:"fastestvpn"`
HideMyAss Servers `json:"hidemyass"`
Ipvanish Servers `json:"ipvanish"`
Ivpn Servers `json:"ivpn"`
Mullvad Servers `json:"mullvad"`
Perfectprivacy Servers `json:"perfect privacy"`
Nordvpn Servers `json:"nordvpn"`
Privado Servers `json:"privado"`
Pia Servers `json:"private internet access"`
Privatevpn Servers `json:"privatevpn"`
Protonvpn Servers `json:"protonvpn"`
Purevpn Servers `json:"purevpn"`
Surfshark Servers `json:"surfshark"`
Torguard Servers `json:"torguard"`
VPNUnlimited Servers `json:"vpn unlimited"`
Vyprvpn Servers `json:"vyprvpn"`
Wevpn Servers `json:"wevpn"`
Windscribe Servers `json:"windscribe"`
} }
func (a *AllServers) Count() int { func (a *AllServers) ServersSlice(provider string) []Server {
return len(a.Cyberghost.Servers) + servers, ok := a.ProviderToServers[provider]
len(a.Expressvpn.Servers) + if !ok {
len(a.Fastestvpn.Servers) + panic(fmt.Sprintf("provider %s not found in all servers", provider))
len(a.HideMyAss.Servers) + }
len(a.Ipvanish.Servers) + return copyServers(servers.Servers)
len(a.Ivpn.Servers) + }
len(a.Mullvad.Servers) +
len(a.Nordvpn.Servers) + var _ json.Marshaler = (*AllServers)(nil)
len(a.Perfectprivacy.Servers) +
len(a.Privado.Servers) + // MarshalJSON marshals all servers to JSON.
len(a.Pia.Servers) + // Note you need to use a pointer to all servers
len(a.Privatevpn.Servers) + // for it to work with native json methods such as
len(a.Protonvpn.Servers) + // json.Marshal.
len(a.Purevpn.Servers) + func (a *AllServers) MarshalJSON() (data []byte, err error) {
len(a.Surfshark.Servers) + buffer := bytes.NewBuffer(nil)
len(a.Torguard.Servers) +
len(a.VPNUnlimited.Servers) + _, err = buffer.WriteString("{")
len(a.Vyprvpn.Servers) + if err != nil {
len(a.Wevpn.Servers) + return nil, fmt.Errorf("cannot write opening bracket: %w", err)
len(a.Windscribe.Servers) }
versionString := fmt.Sprintf(`"version":%d`, a.Version)
_, err = buffer.WriteString(versionString)
if err != nil {
return nil, fmt.Errorf("cannot write schema version string: %w", err)
}
for _, provider := range providers.All() {
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
providerKey := fmt.Sprintf(`,"%s":`, provider)
_, err = buffer.WriteString(providerKey)
if err != nil {
return nil, fmt.Errorf("cannot write provider key %s: %w",
providerKey, err)
}
serversJSON, err := json.Marshal(servers)
if err != nil {
return nil, fmt.Errorf("failed encoding servers for provider %s: %w",
provider, err)
}
_, err = buffer.Write(serversJSON)
if err != nil {
return nil, fmt.Errorf("cannot write JSON servers data for provider %s: %w",
provider, err)
}
}
_, err = buffer.WriteString("}")
if err != nil {
return nil, fmt.Errorf("cannot write closing bracket: %w", err)
}
return buffer.Bytes(), nil
}
var _ json.Unmarshaler = (*AllServers)(nil)
func (a *AllServers) UnmarshalJSON(data []byte) (err error) {
keyValues := make(map[string]interface{})
err = json.Unmarshal(data, &keyValues)
if err != nil {
return err
}
versionUnmarshaled := keyValues["version"]
if versionUnmarshaled != nil { // defaults to 0
version, ok := versionUnmarshaled.(float64)
if !ok {
return &json.UnmarshalTypeError{
Value: fmt.Sprintf("number %v", versionUnmarshaled),
Type: reflect.TypeOf(uint16(0)),
Struct: "models.AllServers",
Field: "Version",
}
}
if math.Round(version) != version ||
version < 0 || version > float64(^uint16(0)) {
return &json.UnmarshalTypeError{
Value: fmt.Sprintf("number %v", version),
Type: reflect.TypeOf(uint16(0)),
Struct: "models.AllServers",
Field: "Version",
}
}
a.Version = uint16(version)
delete(keyValues, "version")
}
if len(keyValues) == 0 {
return nil
}
a.ProviderToServers = make(map[string]Servers, len(keyValues))
allProviders := providers.All()
allProvidersSet := make(map[string]struct{}, len(allProviders))
for _, provider := range allProviders {
allProvidersSet[provider] = struct{}{}
}
for key, value := range keyValues {
if _, ok := allProvidersSet[key]; !ok {
// not a provider known by Gluetun
// or a non-servers field.
continue
}
jsonValue, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("cannot marshal %s servers: %w",
key, err)
}
var servers Servers
err = json.Unmarshal(jsonValue, &servers)
if err != nil {
return fmt.Errorf("cannot unmarshal %s servers: %w",
key, err)
}
a.ProviderToServers[key] = servers
}
return nil
}
func (a *AllServers) Count() (count int) {
for _, servers := range a.ProviderToServers {
count += len(servers.Servers)
}
return count
} }
type Servers struct { type Servers struct {
Version uint16 `json:"version"` Version uint16 `json:"version"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp"`
Servers []Server `json:"servers"` Servers []Server `json:"servers,omitempty"`
} }

View File

@@ -0,0 +1,189 @@
package models
import (
"bytes"
"encoding/json"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_MarshalJSON(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
allServers *AllServers
dataString string
errWrapped error
errMessage string
}{
"empty": {
allServers: &AllServers{
ProviderToServers: map[string]Servers{},
},
dataString: `{"version":0,` +
`"cyberghost":{"version":0,"timestamp":0},` +
`"expressvpn":{"version":0,"timestamp":0},` +
`"fastestvpn":{"version":0,"timestamp":0},` +
`"hidemyass":{"version":0,"timestamp":0},` +
`"ipvanish":{"version":0,"timestamp":0},` +
`"ivpn":{"version":0,"timestamp":0},` +
`"mullvad":{"version":0,"timestamp":0},` +
`"nordvpn":{"version":0,"timestamp":0},` +
`"perfect privacy":{"version":0,"timestamp":0},` +
`"privado":{"version":0,"timestamp":0},` +
`"private internet access":{"version":0,"timestamp":0},` +
`"privatevpn":{"version":0,"timestamp":0},` +
`"protonvpn":{"version":0,"timestamp":0},` +
`"purevpn":{"version":0,"timestamp":0},` +
`"surfshark":{"version":0,"timestamp":0},` +
`"torguard":{"version":0,"timestamp":0},` +
`"vpn unlimited":{"version":0,"timestamp":0},` +
`"vyprvpn":{"version":0,"timestamp":0},` +
`"wevpn":{"version":0,"timestamp":0},` +
`"windscribe":{"version":0,"timestamp":0}}`,
},
"two known providers": {
allServers: &AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 1,
Timestamp: 1000,
Servers: []Server{
{Country: "A"},
{Country: "B"},
},
},
providers.Privado: {
Version: 2,
Timestamp: 2000,
Servers: []Server{
{City: "C"},
{City: "D"},
},
},
},
},
dataString: `{"version":1,` +
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
`"expressvpn":{"version":0,"timestamp":0},` +
`"fastestvpn":{"version":0,"timestamp":0},` +
`"hidemyass":{"version":0,"timestamp":0},` +
`"ipvanish":{"version":0,"timestamp":0},` +
`"ivpn":{"version":0,"timestamp":0},` +
`"mullvad":{"version":0,"timestamp":0},` +
`"nordvpn":{"version":0,"timestamp":0},` +
`"perfect privacy":{"version":0,"timestamp":0},` +
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]},` +
`"private internet access":{"version":0,"timestamp":0},` +
`"privatevpn":{"version":0,"timestamp":0},` +
`"protonvpn":{"version":0,"timestamp":0},` +
`"purevpn":{"version":0,"timestamp":0},` +
`"surfshark":{"version":0,"timestamp":0},` +
`"torguard":{"version":0,"timestamp":0},` +
`"vpn unlimited":{"version":0,"timestamp":0},` +
`"vyprvpn":{"version":0,"timestamp":0},` +
`"wevpn":{"version":0,"timestamp":0},` +
`"windscribe":{"version":0,"timestamp":0}}`,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Populate all providers in all servers
for _, provider := range providers.All() {
_, has := testCase.allServers.ProviderToServers[provider]
if !has {
testCase.allServers.ProviderToServers[provider] = Servers{}
}
}
data, err := testCase.allServers.MarshalJSON()
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.dataString, string(data))
data, err = json.Marshal(testCase.allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.dataString, string(data))
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
// encoder.SetIndent("", " ")
err = encoder.Encode(testCase.allServers)
require.NoError(t, err)
assert.Equal(t, testCase.dataString+"\n", buffer.String())
})
}
}
func Test_AllServers_UnmarshalJSON(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
dataString string
allServers AllServers
errWrapped error
errMessage string
}{
"empty": {
dataString: "{}",
allServers: AllServers{},
},
"two known providers": {
dataString: `{"version":1,` +
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]}}`,
allServers: AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 1,
Timestamp: 1000,
Servers: []Server{
{Country: "A"},
{Country: "B"},
},
},
providers.Privado: {
Version: 2,
Timestamp: 2000,
Servers: []Server{
{City: "C"},
{City: "D"},
},
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
data := []byte(testCase.dataString)
var allServers AllServers
err := json.Unmarshal(data, &allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.allServers, allServers)
})
}
}

View File

@@ -51,50 +51,51 @@ type PortForwarder interface {
} }
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider { func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {
serversSlice := allServers.ServersSlice(provider)
randSource := rand.NewSource(timeNow().UnixNano()) randSource := rand.NewSource(timeNow().UnixNano())
switch provider { switch provider {
case providers.Custom: case providers.Custom:
return custom.New() return custom.New()
case providers.Cyberghost: case providers.Cyberghost:
return cyberghost.New(allServers.Cyberghost.Servers, randSource) return cyberghost.New(serversSlice, randSource)
case providers.Expressvpn: case providers.Expressvpn:
return expressvpn.New(allServers.Expressvpn.Servers, randSource) return expressvpn.New(serversSlice, randSource)
case providers.Fastestvpn: case providers.Fastestvpn:
return fastestvpn.New(allServers.Fastestvpn.Servers, randSource) return fastestvpn.New(serversSlice, randSource)
case providers.HideMyAss: case providers.HideMyAss:
return hidemyass.New(allServers.HideMyAss.Servers, randSource) return hidemyass.New(serversSlice, randSource)
case providers.Ipvanish: case providers.Ipvanish:
return ipvanish.New(allServers.Ipvanish.Servers, randSource) return ipvanish.New(serversSlice, randSource)
case providers.Ivpn: case providers.Ivpn:
return ivpn.New(allServers.Ivpn.Servers, randSource) return ivpn.New(serversSlice, randSource)
case providers.Mullvad: case providers.Mullvad:
return mullvad.New(allServers.Mullvad.Servers, randSource) return mullvad.New(serversSlice, randSource)
case providers.Nordvpn: case providers.Nordvpn:
return nordvpn.New(allServers.Nordvpn.Servers, randSource) return nordvpn.New(serversSlice, randSource)
case providers.Perfectprivacy: case providers.Perfectprivacy:
return perfectprivacy.New(allServers.Perfectprivacy.Servers, randSource) return perfectprivacy.New(serversSlice, randSource)
case providers.Privado: case providers.Privado:
return privado.New(allServers.Privado.Servers, randSource) return privado.New(serversSlice, randSource)
case providers.PrivateInternetAccess: case providers.PrivateInternetAccess:
return privateinternetaccess.New(allServers.Pia.Servers, randSource, timeNow) return privateinternetaccess.New(serversSlice, randSource, timeNow)
case providers.Privatevpn: case providers.Privatevpn:
return privatevpn.New(allServers.Privatevpn.Servers, randSource) return privatevpn.New(serversSlice, randSource)
case providers.Protonvpn: case providers.Protonvpn:
return protonvpn.New(allServers.Protonvpn.Servers, randSource) return protonvpn.New(serversSlice, randSource)
case providers.Purevpn: case providers.Purevpn:
return purevpn.New(allServers.Purevpn.Servers, randSource) return purevpn.New(serversSlice, randSource)
case providers.Surfshark: case providers.Surfshark:
return surfshark.New(allServers.Surfshark.Servers, randSource) return surfshark.New(serversSlice, randSource)
case providers.Torguard: case providers.Torguard:
return torguard.New(allServers.Torguard.Servers, randSource) return torguard.New(serversSlice, randSource)
case providers.VPNUnlimited: case providers.VPNUnlimited:
return vpnunlimited.New(allServers.VPNUnlimited.Servers, randSource) return vpnunlimited.New(serversSlice, randSource)
case providers.Vyprvpn: case providers.Vyprvpn:
return vyprvpn.New(allServers.Vyprvpn.Servers, randSource) return vyprvpn.New(serversSlice, randSource)
case providers.Wevpn: case providers.Wevpn:
return wevpn.New(allServers.Wevpn.Servers, randSource) return wevpn.New(serversSlice, randSource)
case providers.Windscribe: case providers.Windscribe:
return windscribe.New(allServers.Windscribe.Servers, randSource) return windscribe.New(serversSlice, randSource)
default: default:
panic("provider " + provider + " is unknown") // should never occur panic("provider " + provider + " is unknown") // should never occur
} }

View File

@@ -11,14 +11,14 @@ import (
var _ Flusher = (*Storage)(nil) var _ Flusher = (*Storage)(nil)
type Flusher interface { type Flusher interface {
FlushToFile(allServers models.AllServers) error FlushToFile(allServers *models.AllServers) error
} }
func (s *Storage) FlushToFile(allServers models.AllServers) error { func (s *Storage) FlushToFile(allServers *models.AllServers) error {
return flushToFile(s.filepath, allServers) return flushToFile(s.filepath, allServers)
} }
func flushToFile(path string, servers models.AllServers) error { func flushToFile(path string, servers *models.AllServers) error {
dirPath := filepath.Dir(path) dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil { if err := os.MkdirAll(dirPath, 0644); err != nil {
return err return err

View File

@@ -3,6 +3,8 @@ package storage
import ( import (
"testing" "testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -12,5 +14,13 @@ func Test_parseHardcodedServers(t *testing.T) {
servers, err := parseHardcodedServers() servers, err := parseHardcodedServers()
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, len(servers.Cyberghost.Servers))
// all providers minus custom
allProviders := providers.All()
require.Equal(t, len(allProviders), len(servers.ProviderToServers))
for _, provider := range allProviders {
servers, ok := servers.ProviderToServers[provider]
assert.Truef(t, ok, "for provider %s", provider)
assert.NotEmptyf(t, servers, "for provider %s", provider)
}
} }

View File

@@ -28,29 +28,20 @@ func (s *Storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int6
} }
func (s *Storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers { func (s *Storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
return models.AllServers{ allProviders := providers.All()
Version: hardcoded.Version, merged := models.AllServers{
Cyberghost: s.mergeProviderServers(providers.Cyberghost, hardcoded.Cyberghost, persisted.Cyberghost), Version: hardcoded.Version,
Expressvpn: s.mergeProviderServers(providers.Expressvpn, hardcoded.Expressvpn, persisted.Expressvpn), ProviderToServers: make(map[string]models.Servers, len(allProviders)),
Fastestvpn: s.mergeProviderServers(providers.Fastestvpn, hardcoded.Fastestvpn, persisted.Fastestvpn),
HideMyAss: s.mergeProviderServers(providers.HideMyAss, hardcoded.HideMyAss, persisted.HideMyAss),
Ipvanish: s.mergeProviderServers(providers.Ipvanish, hardcoded.Ipvanish, persisted.Ipvanish),
Ivpn: s.mergeProviderServers(providers.Ivpn, hardcoded.Ivpn, persisted.Ivpn),
Mullvad: s.mergeProviderServers(providers.Mullvad, hardcoded.Mullvad, persisted.Mullvad),
Nordvpn: s.mergeProviderServers(providers.Nordvpn, hardcoded.Nordvpn, persisted.Nordvpn),
Perfectprivacy: s.mergeProviderServers(providers.Perfectprivacy, hardcoded.Perfectprivacy, persisted.Perfectprivacy),
Privado: s.mergeProviderServers(providers.Privado, hardcoded.Privado, persisted.Privado),
Pia: s.mergeProviderServers(providers.PrivateInternetAccess, hardcoded.Pia, persisted.Pia),
Privatevpn: s.mergeProviderServers(providers.Privatevpn, hardcoded.Privatevpn, persisted.Privatevpn),
Protonvpn: s.mergeProviderServers(providers.Protonvpn, hardcoded.Protonvpn, persisted.Protonvpn),
Purevpn: s.mergeProviderServers(providers.Purevpn, hardcoded.Purevpn, persisted.Purevpn),
Surfshark: s.mergeProviderServers(providers.Surfshark, hardcoded.Surfshark, persisted.Surfshark),
Torguard: s.mergeProviderServers(providers.Torguard, hardcoded.Torguard, persisted.Torguard),
VPNUnlimited: s.mergeProviderServers(providers.VPNUnlimited, hardcoded.VPNUnlimited, persisted.VPNUnlimited),
Vyprvpn: s.mergeProviderServers(providers.Vyprvpn, hardcoded.Vyprvpn, persisted.Vyprvpn),
Wevpn: s.mergeProviderServers(providers.Wevpn, hardcoded.Wevpn, persisted.Wevpn),
Windscribe: s.mergeProviderServers(providers.Windscribe, hardcoded.Windscribe, persisted.Windscribe),
} }
for _, provider := range allProviders {
hardcodedServers := hardcoded.ProviderToServers[provider]
persistedServers := persisted.ProviderToServers[provider]
merged.ProviderToServers[provider] = s.mergeProviderServers(provider,
hardcodedServers, persistedServers)
}
return merged
} }
func (s *Storage) mergeProviderServers(provider string, func (s *Storage) mergeProviderServers(provider string,

View File

@@ -38,172 +38,39 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) ( func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) (
servers models.AllServers, err error) { servers models.AllServers, err error) {
var versions allVersions rawMessages := make(map[string]json.RawMessage)
if err := json.Unmarshal(b, &versions); err != nil {
return servers, fmt.Errorf("cannot decode versions: %w", err)
}
var rawMessages allJSONRawMessages
if err := json.Unmarshal(b, &rawMessages); err != nil { if err := json.Unmarshal(b, &rawMessages); err != nil {
return servers, fmt.Errorf("cannot decode servers: %w", err) return servers, fmt.Errorf("cannot decode servers: %w", err)
} }
type element struct { // Note schema version is at map key "version" as number
provider string
hardcoded models.Servers
serverVersion serverVersion
rawMessage json.RawMessage
target *models.Servers
}
elements := []element{
{
provider: providers.Cyberghost,
hardcoded: hardcoded.Cyberghost,
serverVersion: versions.Cyberghost,
rawMessage: rawMessages.Cyberghost,
target: &servers.Cyberghost,
},
{
provider: providers.Expressvpn,
hardcoded: hardcoded.Expressvpn,
serverVersion: versions.Expressvpn,
rawMessage: rawMessages.Expressvpn,
target: &servers.Expressvpn,
},
{
provider: providers.Fastestvpn,
hardcoded: hardcoded.Fastestvpn,
serverVersion: versions.Fastestvpn,
rawMessage: rawMessages.Fastestvpn,
target: &servers.Fastestvpn,
},
{
provider: providers.HideMyAss,
hardcoded: hardcoded.HideMyAss,
serverVersion: versions.HideMyAss,
rawMessage: rawMessages.HideMyAss,
target: &servers.HideMyAss,
},
{
provider: providers.Ipvanish,
hardcoded: hardcoded.Ipvanish,
serverVersion: versions.Ipvanish,
rawMessage: rawMessages.Ipvanish,
target: &servers.Ipvanish,
},
{
provider: providers.Ivpn,
hardcoded: hardcoded.Ivpn,
serverVersion: versions.Ivpn,
rawMessage: rawMessages.Ivpn,
target: &servers.Ivpn,
},
{
provider: providers.Mullvad,
hardcoded: hardcoded.Mullvad,
serverVersion: versions.Mullvad,
rawMessage: rawMessages.Mullvad,
target: &servers.Mullvad,
},
{
provider: providers.Nordvpn,
hardcoded: hardcoded.Nordvpn,
serverVersion: versions.Nordvpn,
rawMessage: rawMessages.Nordvpn,
target: &servers.Nordvpn,
},
{
provider: providers.Perfectprivacy,
hardcoded: hardcoded.Perfectprivacy,
serverVersion: versions.Perfectprivacy,
rawMessage: rawMessages.Perfectprivacy,
target: &servers.Perfectprivacy,
},
{
provider: providers.Privado,
hardcoded: hardcoded.Privado,
serverVersion: versions.Privado,
rawMessage: rawMessages.Privado,
target: &servers.Privado,
},
{
provider: providers.PrivateInternetAccess,
hardcoded: hardcoded.Pia,
serverVersion: versions.Pia,
rawMessage: rawMessages.Pia,
target: &servers.Pia,
},
{
provider: providers.Privatevpn,
hardcoded: hardcoded.Privatevpn,
serverVersion: versions.Privatevpn,
rawMessage: rawMessages.Privatevpn,
target: &servers.Privatevpn,
},
{
provider: providers.Protonvpn,
hardcoded: hardcoded.Protonvpn,
serverVersion: versions.Protonvpn,
rawMessage: rawMessages.Protonvpn,
target: &servers.Protonvpn,
},
{
provider: providers.Purevpn,
hardcoded: hardcoded.Purevpn,
serverVersion: versions.Purevpn,
rawMessage: rawMessages.Purevpn,
target: &servers.Purevpn,
},
{
provider: providers.Surfshark,
hardcoded: hardcoded.Surfshark,
serverVersion: versions.Surfshark,
rawMessage: rawMessages.Surfshark,
target: &servers.Surfshark,
},
{
provider: providers.Torguard,
hardcoded: hardcoded.Torguard,
serverVersion: versions.Torguard,
rawMessage: rawMessages.Torguard,
target: &servers.Torguard,
},
{
provider: providers.VPNUnlimited,
hardcoded: hardcoded.VPNUnlimited,
serverVersion: versions.VPNUnlimited,
rawMessage: rawMessages.VPNUnlimited,
target: &servers.VPNUnlimited,
},
{
provider: providers.Vyprvpn,
hardcoded: hardcoded.Vyprvpn,
serverVersion: versions.Vyprvpn,
rawMessage: rawMessages.Vyprvpn,
target: &servers.Vyprvpn,
},
{
provider: providers.Wevpn,
hardcoded: hardcoded.Wevpn,
serverVersion: versions.Wevpn,
rawMessage: rawMessages.Wevpn,
target: &servers.Wevpn,
},
{
provider: providers.Windscribe,
hardcoded: hardcoded.Windscribe,
serverVersion: versions.Windscribe,
rawMessage: rawMessages.Windscribe,
target: &servers.Windscribe,
},
}
for _, element := range elements { allProviders := providers.All()
*element.target, err = s.readServers(element.provider, servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
element.hardcoded, element.serverVersion, element.rawMessage) for _, provider := range allProviders {
if err != nil { hardcoded, ok := hardcoded.ProviderToServers[provider]
return servers, err if !ok {
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
} }
rawMessage, ok := rawMessages[provider]
if !ok {
// If the provider is not found in the data bytes, just don't set it in
// the providers map. That way the hardcoded servers will override them.
// This is user provided and could come from different sources in the
// future (e.g. a file or API request).
continue
}
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage)
if err != nil {
return models.AllServers{}, err
} else if !versionsMatch {
// mergedServers is the empty struct in this case, so don't set the key
// in the providerToServers map.
continue
}
servers.ProviderToServers[provider] = mergedServers
} }
return servers, nil return servers, nil
@@ -214,73 +81,20 @@ var (
) )
func (s *Storage) readServers(provider string, hardcoded models.Servers, func (s *Storage) readServers(provider string, hardcoded models.Servers,
serverVersion serverVersion, rawMessage json.RawMessage) ( rawMessage json.RawMessage) (servers models.Servers, versionsMatch bool, err error) {
servers models.Servers, err error) {
provider = strings.Title(provider) provider = strings.Title(provider)
if hardcoded.Version != serverVersion.Version {
s.logVersionDiff(provider, hardcoded.Version, serverVersion.Version)
return servers, nil
}
err = json.Unmarshal(rawMessage, &servers) var persistedServers models.Servers
err = json.Unmarshal(rawMessage, &persistedServers)
if err != nil { if err != nil {
return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err) return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
} }
return servers, nil versionsMatch = hardcoded.Version == persistedServers.Version
} if !versionsMatch {
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version)
return servers, versionsMatch, nil
}
// allVersions is a subset of models.AllServers structure used to track return persistedServers, versionsMatch, nil
// versions to avoid unmarshaling errors.
type allVersions struct {
Version uint16 `json:"version"` // used for migration of the top level scheme
Cyberghost serverVersion `json:"cyberghost"`
Expressvpn serverVersion `json:"expressvpn"`
Fastestvpn serverVersion `json:"fastestvpn"`
HideMyAss serverVersion `json:"hidemyass"`
Ipvanish serverVersion `json:"ipvanish"`
Ivpn serverVersion `json:"ivpn"`
Mullvad serverVersion `json:"mullvad"`
Nordvpn serverVersion `json:"nordvpn"`
Perfectprivacy serverVersion `json:"perfect privacy"`
Privado serverVersion `json:"privado"`
Pia serverVersion `json:"private internet access"`
Privatevpn serverVersion `json:"privatevpn"`
Protonvpn serverVersion `json:"protonvpn"`
Purevpn serverVersion `json:"purevpn"`
Surfshark serverVersion `json:"surfshark"`
Torguard serverVersion `json:"torguard"`
VPNUnlimited serverVersion `json:"vpn unlimited"`
Vyprvpn serverVersion `json:"vyprvpn"`
Wevpn serverVersion `json:"wevpn"`
Windscribe serverVersion `json:"windscribe"`
}
type serverVersion struct {
Version uint16 `json:"version"`
}
// allJSONRawMessages is to delay decoding of each provider servers.
type allJSONRawMessages struct {
Version uint16 `json:"version"` // used for migration of the top level scheme
Cyberghost json.RawMessage `json:"cyberghost"`
Expressvpn json.RawMessage `json:"expressvpn"`
Fastestvpn json.RawMessage `json:"fastestvpn"`
HideMyAss json.RawMessage `json:"hidemyass"`
Ipvanish json.RawMessage `json:"ipvanish"`
Ivpn json.RawMessage `json:"ivpn"`
Mullvad json.RawMessage `json:"mullvad"`
Nordvpn json.RawMessage `json:"nordvpn"`
Perfectprivacy json.RawMessage `json:"perfect privacy"`
Privado json.RawMessage `json:"privado"`
Pia json.RawMessage `json:"private internet access"`
Privatevpn json.RawMessage `json:"privatevpn"`
Protonvpn json.RawMessage `json:"protonvpn"`
Purevpn json.RawMessage `json:"purevpn"`
Surfshark json.RawMessage `json:"surfshark"`
Torguard json.RawMessage `json:"torguard"`
VPNUnlimited json.RawMessage `json:"vpn unlimited"`
Vyprvpn json.RawMessage `json:"vyprvpn"`
Wevpn json.RawMessage `json:"wevpn"`
Windscribe json.RawMessage `json:"windscribe"`
} }

View File

@@ -1,80 +1,89 @@
package storage package storage
import ( import (
"errors" "fmt"
"testing" "testing"
gomock "github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
servers models.AllServers) models.AllServers {
allProviders := providers.All()
if servers.ProviderToServers == nil {
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)-1)
}
for _, provider := range allProviders {
_, has := servers.ProviderToServers[provider]
if has {
continue
}
servers.ProviderToServers[provider] = models.Servers{
Version: allProviderVersion,
Timestamp: allProviderTimestamp,
}
}
return servers
}
func Test_extractServersFromBytes(t *testing.T) { func Test_extractServersFromBytes(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
b []byte b []byte
hardcoded models.AllServers hardcoded models.AllServers
logged []string logged []string
persisted models.AllServers persisted models.AllServers
err error errMessage string
}{ }{
"no data": { "bad JSON": {
err: errors.New("cannot decode versions: unexpected end of JSON input"), b: []byte("garbage"),
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
}, },
"empty JSON": { "bad provider JSON": {
b: []byte("{}"), b: []byte(`{"cyberghost": "garbage"}`),
err: errors.New("cannot decode servers for provider: Cyberghost: unexpected end of JSON input"), hardcoded: populateProviders(1, 0, models.AllServers{}),
errMessage: "cannot decode servers for provider: Cyberghost: " +
"json: cannot unmarshal string into Go value of type models.Servers",
}, },
"different versions": { "absent provider keys": {
b: []byte(`{}`), b: []byte(`{}`),
hardcoded: models.AllServers{ hardcoded: populateProviders(1, 0, models.AllServers{}),
Cyberghost: models.Servers{Version: 1}, persisted: models.AllServers{
Expressvpn: models.Servers{Version: 1}, ProviderToServers: map[string]models.Servers{},
Fastestvpn: models.Servers{Version: 1},
HideMyAss: models.Servers{Version: 1},
Ipvanish: models.Servers{Version: 1},
Ivpn: models.Servers{Version: 1},
Mullvad: models.Servers{Version: 1},
Nordvpn: models.Servers{Version: 1},
Perfectprivacy: models.Servers{Version: 1},
Privado: models.Servers{Version: 1},
Pia: models.Servers{Version: 1},
Privatevpn: models.Servers{Version: 1},
Protonvpn: models.Servers{Version: 1},
Purevpn: models.Servers{Version: 1},
Surfshark: models.Servers{Version: 1},
Torguard: models.Servers{Version: 1},
VPNUnlimited: models.Servers{Version: 1},
Vyprvpn: models.Servers{Version: 1},
Wevpn: models.Servers{Version: 1},
Windscribe: models.Servers{Version: 1},
},
logged: []string{
"Cyberghost servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Expressvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Fastestvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Hidemyass servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Ipvanish servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Ivpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Mullvad servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Nordvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Perfect Privacy servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Privado servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Private Internet Access servers from file discarded because they have version 0 and hardcoded servers have version 1", //nolint:lll
"Privatevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Protonvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Purevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Surfshark servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Torguard servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Vpn Unlimited servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Vyprvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Wevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Windscribe servers from file discarded because they have version 0 and hardcoded servers have version 1",
}, },
}, },
"same versions": { "same versions": {
b: []byte(`{
"cyberghost": {"version": 1, "timestamp": 1},
"expressvpn": {"version": 1, "timestamp": 1},
"fastestvpn": {"version": 1, "timestamp": 1},
"hidemyass": {"version": 1, "timestamp": 1},
"ipvanish": {"version": 1, "timestamp": 1},
"ivpn": {"version": 1, "timestamp": 1},
"mullvad": {"version": 1, "timestamp": 1},
"nordvpn": {"version": 1, "timestamp": 1},
"perfect privacy": {"version": 1, "timestamp": 1},
"privado": {"version": 1, "timestamp": 1},
"private internet access": {"version": 1, "timestamp": 1},
"privatevpn": {"version": 1, "timestamp": 1},
"protonvpn": {"version": 1, "timestamp": 1},
"purevpn": {"version": 1, "timestamp": 1},
"surfshark": {"version": 1, "timestamp": 1},
"torguard": {"version": 1, "timestamp": 1},
"vpn unlimited": {"version": 1, "timestamp": 1},
"vyprvpn": {"version": 1, "timestamp": 1},
"wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1}
}`),
hardcoded: populateProviders(1, 0, models.AllServers{}),
persisted: populateProviders(1, 1, models.AllServers{}),
},
"different versions": {
b: []byte(`{ b: []byte(`{
"cyberghost": {"version": 1, "timestamp": 1}, "cyberghost": {"version": 1, "timestamp": 1},
"expressvpn": {"version": 1, "timestamp": 1}, "expressvpn": {"version": 1, "timestamp": 1},
@@ -97,49 +106,31 @@ func Test_extractServersFromBytes(t *testing.T) {
"wevpn": {"version": 1, "timestamp": 1}, "wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1} "windscribe": {"version": 1, "timestamp": 1}
}`), }`),
hardcoded: models.AllServers{ hardcoded: populateProviders(2, 0, models.AllServers{}),
Cyberghost: models.Servers{Version: 1}, logged: []string{
Expressvpn: models.Servers{Version: 1}, "Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
Fastestvpn: models.Servers{Version: 1}, "Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
HideMyAss: models.Servers{Version: 1}, "Fastestvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Ipvanish: models.Servers{Version: 1}, "Hidemyass servers from file discarded because they have version 1 and hardcoded servers have version 2",
Ivpn: models.Servers{Version: 1}, "Ipvanish servers from file discarded because they have version 1 and hardcoded servers have version 2",
Mullvad: models.Servers{Version: 1}, "Ivpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Nordvpn: models.Servers{Version: 1}, "Mullvad servers from file discarded because they have version 1 and hardcoded servers have version 2",
Perfectprivacy: models.Servers{Version: 1}, "Nordvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Privado: models.Servers{Version: 1}, "Perfect Privacy servers from file discarded because they have version 1 and hardcoded servers have version 2",
Pia: models.Servers{Version: 1}, "Privado servers from file discarded because they have version 1 and hardcoded servers have version 2",
Privatevpn: models.Servers{Version: 1}, "Private Internet Access servers from file discarded because they have version 1 and hardcoded servers have version 2", //nolint:lll
Protonvpn: models.Servers{Version: 1}, "Privatevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Purevpn: models.Servers{Version: 1}, "Protonvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Surfshark: models.Servers{Version: 1}, "Purevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Torguard: models.Servers{Version: 1}, "Surfshark servers from file discarded because they have version 1 and hardcoded servers have version 2",
VPNUnlimited: models.Servers{Version: 1}, "Torguard servers from file discarded because they have version 1 and hardcoded servers have version 2",
Vyprvpn: models.Servers{Version: 1}, "Vpn Unlimited servers from file discarded because they have version 1 and hardcoded servers have version 2",
Wevpn: models.Servers{Version: 1}, "Vyprvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
Windscribe: models.Servers{Version: 1}, "Wevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Windscribe servers from file discarded because they have version 1 and hardcoded servers have version 2",
}, },
persisted: models.AllServers{ persisted: models.AllServers{
Cyberghost: models.Servers{Version: 1, Timestamp: 1}, ProviderToServers: map[string]models.Servers{},
Expressvpn: models.Servers{Version: 1, Timestamp: 1},
Fastestvpn: models.Servers{Version: 1, Timestamp: 1},
HideMyAss: models.Servers{Version: 1, Timestamp: 1},
Ipvanish: models.Servers{Version: 1, Timestamp: 1},
Ivpn: models.Servers{Version: 1, Timestamp: 1},
Mullvad: models.Servers{Version: 1, Timestamp: 1},
Nordvpn: models.Servers{Version: 1, Timestamp: 1},
Perfectprivacy: models.Servers{Version: 1, Timestamp: 1},
Privado: models.Servers{Version: 1, Timestamp: 1},
Pia: models.Servers{Version: 1, Timestamp: 1},
Privatevpn: models.Servers{Version: 1, Timestamp: 1},
Protonvpn: models.Servers{Version: 1, Timestamp: 1},
Purevpn: models.Servers{Version: 1, Timestamp: 1},
Surfshark: models.Servers{Version: 1, Timestamp: 1},
Torguard: models.Servers{Version: 1, Timestamp: 1},
VPNUnlimited: models.Servers{Version: 1, Timestamp: 1},
Vyprvpn: models.Servers{Version: 1, Timestamp: 1},
Wevpn: models.Servers{Version: 1, Timestamp: 1},
Windscribe: models.Servers{Version: 1, Timestamp: 1},
}, },
}, },
} }
@@ -151,8 +142,13 @@ func Test_extractServersFromBytes(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
logger := NewMockInfoErrorer(ctrl) logger := NewMockInfoErrorer(ctrl)
var previousLogCall *gomock.Call
for _, logged := range testCase.logged { for _, logged := range testCase.logged {
logger.EXPECT().Info(logged) call := logger.EXPECT().Info(logged)
if previousLogCall != nil {
call.After(previousLogCall)
}
previousLogCall = call
} }
s := &Storage{ s := &Storage{
@@ -161,9 +157,8 @@ func Test_extractServersFromBytes(t *testing.T) {
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded) servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded)
if testCase.err != nil { if testCase.errMessage != "" {
require.Error(t, err) assert.EqualError(t, err, testCase.errMessage)
assert.Equal(t, testCase.err.Error(), err.Error())
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -171,4 +166,25 @@ func Test_extractServersFromBytes(t *testing.T) {
assert.Equal(t, testCase.persisted, servers) assert.Equal(t, testCase.persisted, servers)
}) })
} }
t.Run("hardcoded panic", func(t *testing.T) {
t.Parallel()
s := &Storage{}
allProviders := providers.All()
require.GreaterOrEqual(t, len(allProviders), 2)
b := []byte(`{}`)
hardcoded := models.AllServers{
ProviderToServers: map[string]models.Servers{
allProviders[0]: {},
// Missing provider allProviders[1]
},
}
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
assert.PanicsWithValue(t, expectedPanicValue, func() {
_, _ = s.extractServersFromBytes(b, hardcoded)
})
})
} }

View File

@@ -7,27 +7,11 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
func countServers(allServers models.AllServers) int { func countServers(allServers models.AllServers) (count int) {
return len(allServers.Cyberghost.Servers) + for _, servers := range allServers.ProviderToServers {
len(allServers.Expressvpn.Servers) + count += len(servers.Servers)
len(allServers.Fastestvpn.Servers) + }
len(allServers.HideMyAss.Servers) + return count
len(allServers.Ipvanish.Servers) +
len(allServers.Ivpn.Servers) +
len(allServers.Mullvad.Servers) +
len(allServers.Nordvpn.Servers) +
len(allServers.Perfectprivacy.Servers) +
len(allServers.Privado.Servers) +
len(allServers.Pia.Servers) +
len(allServers.Privatevpn.Servers) +
len(allServers.Protonvpn.Servers) +
len(allServers.Purevpn.Servers) +
len(allServers.Surfshark.Servers) +
len(allServers.Torguard.Servers) +
len(allServers.VPNUnlimited.Servers) +
len(allServers.Vyprvpn.Servers) +
len(allServers.Wevpn.Servers) +
len(allServers.Windscribe.Servers)
} }
func (s *Storage) SyncServers() (err error) { func (s *Storage) SyncServers() (err error) {
@@ -57,7 +41,7 @@ func (s *Storage) SyncServers() (err error) {
return nil return nil
} }
if err := flushToFile(s.filepath, s.mergedServers); err != nil { if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
return fmt.Errorf("cannot write servers to file: %w", err) return fmt.Errorf("cannot write servers to file: %w", err)
} }
return nil return nil

View File

@@ -139,7 +139,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
l.stopped <- struct{}{} l.stopped <- struct{}{}
case servers := <-serversCh: case servers := <-serversCh:
l.setAllServers(servers) l.setAllServers(servers)
if err := l.flusher.FlushToFile(servers); err != nil { if err := l.flusher.FlushToFile(&servers); err != nil {
l.logger.Error(err.Error()) l.logger.Error(err.Error())
} }
runWg.Wait() runWg.Wait()

View File

@@ -4,8 +4,10 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"time"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/updater/providers/cyberghost" "github.com/qdm12/gluetun/internal/updater/providers/cyberghost"
"github.com/qdm12/gluetun/internal/updater/providers/expressvpn" "github.com/qdm12/gluetun/internal/updater/providers/expressvpn"
"github.com/qdm12/gluetun/internal/updater/providers/fastestvpn" "github.com/qdm12/gluetun/internal/updater/providers/fastestvpn"
@@ -28,419 +30,96 @@ import (
"github.com/qdm12/gluetun/internal/updater/providers/windscribe" "github.com/qdm12/gluetun/internal/updater/providers/windscribe"
) )
func (u *updater) updateCyberghost(ctx context.Context) (err error) { func (u *updater) updateProvider(ctx context.Context, provider string) (
minServers := getMinServers(len(u.servers.Cyberghost.Servers)) warnings []string, err error) {
servers, err := cyberghost.GetServers(ctx, u.presolver, minServers) existingServers := u.getProviderServers(provider)
minServers := getMinServers(existingServers)
servers, warnings, err := u.getServers(ctx, provider, minServers)
if err != nil { if err != nil {
return err return warnings, err
} }
if reflect.DeepEqual(u.servers.Cyberghost.Servers, servers) { if reflect.DeepEqual(existingServers, servers) {
return nil return warnings, nil
} }
u.servers.Cyberghost.Timestamp = u.timeNow().Unix() u.patchProvider(provider, servers)
u.servers.Cyberghost.Servers = servers return warnings, nil
return nil
} }
func (u *updater) updateExpressvpn(ctx context.Context) (err error) { func (u *updater) getServers(ctx context.Context, provider string,
minServers := getMinServers(len(u.servers.Expressvpn.Servers)) minServers int) (servers []models.Server, warnings []string, err error) {
servers, warnings, err := expressvpn.GetServers( switch provider {
ctx, u.unzipper, u.presolver, minServers) case providers.Custom:
if *u.options.CLI { panic("cannot update custom provider")
for _, warning := range warnings { case providers.Cyberghost:
u.logger.Warn("ExpressVPN: " + warning) servers, err = cyberghost.GetServers(ctx, u.presolver, minServers)
} return servers, nil, err
case providers.Expressvpn:
return expressvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Fastestvpn:
return fastestvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.HideMyAss:
return hidemyass.GetServers(ctx, u.client, u.presolver, minServers)
case providers.Ipvanish:
return ipvanish.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Ivpn:
return ivpn.GetServers(ctx, u.client, u.presolver, minServers)
case providers.Mullvad:
servers, err = mullvad.GetServers(ctx, u.client, minServers)
return servers, nil, err
case providers.Nordvpn:
return nordvpn.GetServers(ctx, u.client, minServers)
case providers.Perfectprivacy:
return perfectprivacy.GetServers(ctx, u.unzipper, minServers)
case providers.Privado:
return privado.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers)
case providers.PrivateInternetAccess:
servers, err = pia.GetServers(ctx, u.client, minServers)
return servers, nil, err
case providers.Privatevpn:
return privatevpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Protonvpn:
return protonvpn.GetServers(ctx, u.client, minServers)
case providers.Purevpn:
return purevpn.GetServers(ctx, u.client, u.unzipper, u.presolver, minServers)
case providers.Surfshark:
return surfshark.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers)
case providers.Torguard:
return torguard.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.VPNUnlimited:
return vpnunlimited.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Vyprvpn:
return vyprvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Wevpn:
return wevpn.GetServers(ctx, u.presolver, minServers)
case providers.Windscribe:
servers, err = windscribe.GetServers(ctx, u.client, minServers)
return servers, nil, err
default:
panic("provider " + provider + " is unknown")
} }
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Expressvpn.Servers, servers) {
return nil
}
u.servers.Expressvpn.Timestamp = u.timeNow().Unix()
u.servers.Expressvpn.Servers = servers
return nil
} }
func (u *updater) updateFastestvpn(ctx context.Context) (err error) { func (u *updater) getProviderServers(provider string) (servers []models.Server) {
minServers := getMinServers(len(u.servers.Fastestvpn.Servers)) providerServers, ok := u.servers.ProviderToServers[provider]
servers, warnings, err := fastestvpn.GetServers( if !ok {
ctx, u.unzipper, u.presolver, minServers) panic(fmt.Sprintf("provider %s is unknown", provider))
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("FastestVPN: " + warning)
}
} }
if err != nil { return providerServers.Servers
return err
}
if reflect.DeepEqual(u.servers.Fastestvpn.Servers, servers) {
return nil
}
u.servers.Fastestvpn.Timestamp = u.timeNow().Unix()
u.servers.Fastestvpn.Servers = servers
return nil
} }
func (u *updater) updateHideMyAss(ctx context.Context) (err error) { func getMinServers(servers []models.Server) (minServers int) {
minServers := getMinServers(len(u.servers.HideMyAss.Servers))
servers, warnings, err := hidemyass.GetServers(
ctx, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("HideMyAss: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.HideMyAss.Servers, servers) {
return nil
}
u.servers.HideMyAss.Timestamp = u.timeNow().Unix()
u.servers.HideMyAss.Servers = servers
return nil
}
func (u *updater) updateIpvanish(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Ipvanish.Servers))
servers, warnings, err := ipvanish.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Ipvanish: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Ipvanish.Servers, servers) {
return nil
}
u.servers.Ipvanish.Timestamp = u.timeNow().Unix()
u.servers.Ipvanish.Servers = servers
return nil
}
func (u *updater) updateIvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Ivpn.Servers))
servers, warnings, err := ivpn.GetServers(
ctx, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Ivpn: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Ivpn.Servers, servers) {
return nil
}
u.servers.Ivpn.Timestamp = u.timeNow().Unix()
u.servers.Ivpn.Servers = servers
return nil
}
func (u *updater) updateMullvad(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Mullvad.Servers))
servers, err := mullvad.GetServers(ctx, u.client, minServers)
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Mullvad.Servers, servers) {
return nil
}
u.servers.Mullvad.Timestamp = u.timeNow().Unix()
u.servers.Mullvad.Servers = servers
return nil
}
func (u *updater) updateNordvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Nordvpn.Servers))
servers, warnings, err := nordvpn.GetServers(ctx, u.client, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("NordVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Nordvpn.Servers, servers) {
return nil
}
u.servers.Nordvpn.Timestamp = u.timeNow().Unix()
u.servers.Nordvpn.Servers = servers
return nil
}
func (u *updater) updatePerfectprivacy(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Perfectprivacy.Servers))
servers, warnings, err := perfectprivacy.GetServers(ctx, u.unzipper, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn(providers.Perfectprivacy + ": " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Perfectprivacy.Servers, servers) {
return nil
}
u.servers.Perfectprivacy.Timestamp = u.timeNow().Unix()
u.servers.Perfectprivacy.Servers = servers
return nil
}
func (u *updater) updatePIA(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Pia.Servers))
servers, err := pia.GetServers(ctx, u.client, minServers)
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Pia.Servers, servers) {
return nil
}
u.servers.Pia.Timestamp = u.timeNow().Unix()
u.servers.Pia.Servers = servers
return nil
}
func (u *updater) updatePrivado(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Privado.Servers))
servers, warnings, err := privado.GetServers(
ctx, u.unzipper, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Privado: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Privado.Servers, servers) {
return nil
}
u.servers.Privado.Timestamp = u.timeNow().Unix()
u.servers.Privado.Servers = servers
return nil
}
func (u *updater) updatePrivatevpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Privatevpn.Servers))
servers, warnings, err := privatevpn.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("PrivateVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Privatevpn.Servers, servers) {
return nil
}
u.servers.Privatevpn.Timestamp = u.timeNow().Unix()
u.servers.Privatevpn.Servers = servers
return nil
}
func (u *updater) updateProtonvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Privatevpn.Servers))
servers, warnings, err := protonvpn.GetServers(ctx, u.client, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("ProtonVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Protonvpn.Servers, servers) {
return nil
}
u.servers.Protonvpn.Timestamp = u.timeNow().Unix()
u.servers.Protonvpn.Servers = servers
return nil
}
func (u *updater) updatePurevpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Purevpn.Servers))
servers, warnings, err := purevpn.GetServers(
ctx, u.client, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("PureVPN: " + warning)
}
}
if err != nil {
return fmt.Errorf("cannot update Purevpn servers: %w", err)
}
if reflect.DeepEqual(u.servers.Purevpn.Servers, servers) {
return nil
}
u.servers.Purevpn.Timestamp = u.timeNow().Unix()
u.servers.Purevpn.Servers = servers
return nil
}
func (u *updater) updateSurfshark(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Surfshark.Servers))
servers, warnings, err := surfshark.GetServers(
ctx, u.unzipper, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Surfshark: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Surfshark.Servers, servers) {
return nil
}
u.servers.Surfshark.Timestamp = u.timeNow().Unix()
u.servers.Surfshark.Servers = servers
return nil
}
func (u *updater) updateTorguard(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Torguard.Servers))
servers, warnings, err := torguard.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Torguard: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Torguard.Servers, servers) {
return nil
}
u.servers.Torguard.Timestamp = u.timeNow().Unix()
u.servers.Torguard.Servers = servers
return nil
}
func (u *updater) updateVPNUnlimited(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.VPNUnlimited.Servers))
servers, warnings, err := vpnunlimited.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn(providers.VPNUnlimited + ": " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.VPNUnlimited.Servers, servers) {
return nil
}
u.servers.VPNUnlimited.Timestamp = u.timeNow().Unix()
u.servers.VPNUnlimited.Servers = servers
return nil
}
func (u *updater) updateVyprvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Vyprvpn.Servers))
servers, warnings, err := vyprvpn.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("VyprVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Vyprvpn.Servers, servers) {
return nil
}
u.servers.Vyprvpn.Timestamp = u.timeNow().Unix()
u.servers.Vyprvpn.Servers = servers
return nil
}
func (u *updater) updateWevpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Wevpn.Servers))
servers, warnings, err := wevpn.GetServers(ctx, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("WeVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Wevpn.Servers, servers) {
return nil
}
u.servers.Wevpn.Timestamp = u.timeNow().Unix()
u.servers.Wevpn.Servers = servers
return nil
}
func (u *updater) updateWindscribe(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Windscribe.Servers))
servers, err := windscribe.GetServers(ctx, u.client, minServers)
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Windscribe.Servers, servers) {
return nil
}
u.servers.Windscribe.Timestamp = u.timeNow().Unix()
u.servers.Windscribe.Servers = servers
return nil
}
func getMinServers(existingServers int) (minServers int) {
const minRatio = 0.8 const minRatio = 0.8
return int(minRatio * float64(existingServers)) return int(minRatio * float64(len(servers)))
}
func (u *updater) patchProvider(provider string, servers []models.Server) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
providerServers.Timestamp = time.Now().Unix()
providerServers.Servers = servers
u.servers.ProviderToServers[provider] = providerServers
} }

View File

@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/updater/unzip"
@@ -47,16 +46,17 @@ func New(settings settings.Updater, httpClient *http.Client,
} }
} }
type updateFunc func(ctx context.Context) (err error)
func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) { func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) {
for _, provider := range u.options.Providers { for _, provider := range u.options.Providers {
u.logger.Info("updating " + strings.Title(provider) + " servers...") u.logger.Info("updating " + strings.Title(provider) + " servers...")
updateProvider := u.getUpdateFunction(provider)
// TODO support servers offering only TCP or only UDP // TODO support servers offering only TCP or only UDP
// for NordVPN and PureVPN // for NordVPN and PureVPN
err = updateProvider(ctx) warnings, err := u.updateProvider(ctx, provider)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn(provider + ": " + warning)
}
}
if err != nil { if err != nil {
if ctxErr := ctx.Err(); ctxErr != nil { if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr return allServers, ctxErr
@@ -67,52 +67,3 @@ func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServe
return u.servers, nil return u.servers, nil
} }
func (u *updater) getUpdateFunction(provider string) (updateFunction updateFunc) {
switch provider {
case providers.Custom:
panic("cannot update custom provider")
case providers.Cyberghost:
return func(ctx context.Context) (err error) { return u.updateCyberghost(ctx) }
case providers.Expressvpn:
return func(ctx context.Context) (err error) { return u.updateExpressvpn(ctx) }
case providers.Fastestvpn:
return func(ctx context.Context) (err error) { return u.updateFastestvpn(ctx) }
case providers.HideMyAss:
return func(ctx context.Context) (err error) { return u.updateHideMyAss(ctx) }
case providers.Ipvanish:
return func(ctx context.Context) (err error) { return u.updateIpvanish(ctx) }
case providers.Ivpn:
return func(ctx context.Context) (err error) { return u.updateIvpn(ctx) }
case providers.Mullvad:
return func(ctx context.Context) (err error) { return u.updateMullvad(ctx) }
case providers.Nordvpn:
return func(ctx context.Context) (err error) { return u.updateNordvpn(ctx) }
case providers.Perfectprivacy:
return func(ctx context.Context) (err error) { return u.updatePerfectprivacy(ctx) }
case providers.Privado:
return func(ctx context.Context) (err error) { return u.updatePrivado(ctx) }
case providers.PrivateInternetAccess:
return func(ctx context.Context) (err error) { return u.updatePIA(ctx) }
case providers.Privatevpn:
return func(ctx context.Context) (err error) { return u.updatePrivatevpn(ctx) }
case providers.Protonvpn:
return func(ctx context.Context) (err error) { return u.updateProtonvpn(ctx) }
case providers.Purevpn:
return func(ctx context.Context) (err error) { return u.updatePurevpn(ctx) }
case providers.Surfshark:
return func(ctx context.Context) (err error) { return u.updateSurfshark(ctx) }
case providers.Torguard:
return func(ctx context.Context) (err error) { return u.updateTorguard(ctx) }
case providers.VPNUnlimited:
return func(ctx context.Context) (err error) { return u.updateVPNUnlimited(ctx) }
case providers.Vyprvpn:
return func(ctx context.Context) (err error) { return u.updateVyprvpn(ctx) }
case providers.Wevpn:
return func(ctx context.Context) (err error) { return u.updateWevpn(ctx) }
case providers.Windscribe:
return func(ctx context.Context) (err error) { return u.updateWindscribe(ctx) }
default:
panic("provider " + provider + " is unknown")
}
}