chore(all): provider to servers map in allServers
- Simplify formatting CLI - Simplify updater code - Simplify filter choices for config validation - Simplify all servers deep copying - Custom JSON marshaling methods for `AllServers` - Simplify provider constructor switch - Simplify storage merging - Simplify storage reading and extraction - Simplify updating code
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
)
|
||||
|
||||
@@ -20,6 +21,7 @@ type ServersFormatter interface {
|
||||
var (
|
||||
ErrFormatNotRecognized = errors.New("format is not recognized")
|
||||
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
|
||||
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
|
||||
)
|
||||
|
||||
func addProviderFlag(flagSet *flag.FlagSet,
|
||||
@@ -31,21 +33,12 @@ func addProviderFlag(flagSet *flag.FlagSet,
|
||||
flagSet.BoolVar(boolPtr, provider, false, "Format "+strings.Title(provider)+" servers")
|
||||
}
|
||||
|
||||
func getFormatForProvider(providerToFormat map[string]*bool, provider string) (format bool) {
|
||||
formatPtr, ok := providerToFormat[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
|
||||
}
|
||||
return *formatPtr
|
||||
}
|
||||
|
||||
func (c *CLI) FormatServers(args []string) error {
|
||||
var format, output string
|
||||
allProviders := providers.All()
|
||||
providersToFormat := make(map[string]*bool, len(allProviders))
|
||||
for _, provider := range allProviders {
|
||||
value := false
|
||||
providersToFormat[provider] = &value
|
||||
providersToFormat[provider] = new(bool)
|
||||
}
|
||||
flagSet := flag.NewFlagSet("markdown", flag.ExitOnError)
|
||||
flagSet.StringVar(&format, "format", "markdown", "Format to use which can be: 'markdown'")
|
||||
@@ -61,6 +54,24 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
return fmt.Errorf("%w: %s", ErrFormatNotRecognized, format)
|
||||
}
|
||||
|
||||
// Verify only one provider is set to be formatted.
|
||||
var providers []string
|
||||
for provider, formatPtr := range providersToFormat {
|
||||
if *formatPtr {
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
}
|
||||
switch len(providers) {
|
||||
case 0:
|
||||
return ErrProviderUnspecified
|
||||
case 1:
|
||||
default:
|
||||
return fmt.Errorf("%w: %d specified: %s",
|
||||
ErrMultipleProvidersToFormat, len(providers),
|
||||
strings.Join(providers, ", "))
|
||||
}
|
||||
providerToFormat := providers[0]
|
||||
|
||||
logger := newNoopLogger()
|
||||
storage, err := storage.New(logger, constants.ServersData)
|
||||
if err != nil {
|
||||
@@ -68,51 +79,7 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
}
|
||||
currentServers := storage.GetServers()
|
||||
|
||||
var formatted string
|
||||
switch {
|
||||
case getFormatForProvider(providersToFormat, providers.Cyberghost):
|
||||
formatted = currentServers.Cyberghost.ToMarkdown(providers.Cyberghost)
|
||||
case getFormatForProvider(providersToFormat, providers.Expressvpn):
|
||||
formatted = currentServers.Expressvpn.ToMarkdown(providers.Expressvpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Fastestvpn):
|
||||
formatted = currentServers.Fastestvpn.ToMarkdown(providers.Fastestvpn)
|
||||
case getFormatForProvider(providersToFormat, providers.HideMyAss):
|
||||
formatted = currentServers.HideMyAss.ToMarkdown(providers.HideMyAss)
|
||||
case getFormatForProvider(providersToFormat, providers.Ipvanish):
|
||||
formatted = currentServers.Ipvanish.ToMarkdown(providers.Ipvanish)
|
||||
case getFormatForProvider(providersToFormat, providers.Ivpn):
|
||||
formatted = currentServers.Ivpn.ToMarkdown(providers.Ivpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Mullvad):
|
||||
formatted = currentServers.Mullvad.ToMarkdown(providers.Mullvad)
|
||||
case getFormatForProvider(providersToFormat, providers.Nordvpn):
|
||||
formatted = currentServers.Nordvpn.ToMarkdown(providers.Nordvpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Perfectprivacy):
|
||||
formatted = currentServers.Perfectprivacy.ToMarkdown(providers.Perfectprivacy)
|
||||
case getFormatForProvider(providersToFormat, providers.PrivateInternetAccess):
|
||||
formatted = currentServers.Pia.ToMarkdown(providers.PrivateInternetAccess)
|
||||
case getFormatForProvider(providersToFormat, providers.Privado):
|
||||
formatted = currentServers.Privado.ToMarkdown(providers.Privado)
|
||||
case getFormatForProvider(providersToFormat, providers.Privatevpn):
|
||||
formatted = currentServers.Privatevpn.ToMarkdown(providers.Privatevpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Protonvpn):
|
||||
formatted = currentServers.Protonvpn.ToMarkdown(providers.Protonvpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Purevpn):
|
||||
formatted = currentServers.Purevpn.ToMarkdown(providers.Purevpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Surfshark):
|
||||
formatted = currentServers.Surfshark.ToMarkdown(providers.Surfshark)
|
||||
case getFormatForProvider(providersToFormat, providers.Torguard):
|
||||
formatted = currentServers.Torguard.ToMarkdown(providers.Torguard)
|
||||
case getFormatForProvider(providersToFormat, providers.VPNUnlimited):
|
||||
formatted = currentServers.VPNUnlimited.ToMarkdown(providers.VPNUnlimited)
|
||||
case getFormatForProvider(providersToFormat, providers.Vyprvpn):
|
||||
formatted = currentServers.Vyprvpn.ToMarkdown(providers.Vyprvpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Wevpn):
|
||||
formatted = currentServers.Wevpn.ToMarkdown(providers.Wevpn)
|
||||
case getFormatForProvider(providersToFormat, providers.Windscribe):
|
||||
formatted = currentServers.Windscribe.ToMarkdown(providers.Windscribe)
|
||||
default:
|
||||
return ErrProviderUnspecified
|
||||
}
|
||||
formatted := formatServers(currentServers, providerToFormat)
|
||||
|
||||
output = filepath.Clean(output)
|
||||
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
|
||||
@@ -133,3 +100,11 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatServers(allServers models.AllServers, provider string) (formatted string) {
|
||||
servers, ok := allServers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
|
||||
}
|
||||
return servers.ToMarkdown(providers.Cyberghost)
|
||||
}
|
||||
|
||||
@@ -63,12 +63,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
}
|
||||
|
||||
if updateAll {
|
||||
for _, provider := range providers.All() {
|
||||
if provider == providers.Custom {
|
||||
continue
|
||||
}
|
||||
options.Providers = append(options.Providers, provider)
|
||||
}
|
||||
options.Providers = providers.All()
|
||||
} else {
|
||||
if csvProviders == "" {
|
||||
return ErrNoProviderSpecified
|
||||
@@ -99,13 +94,13 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
}
|
||||
|
||||
if endUserMode {
|
||||
if err := storage.FlushToFile(allServers); err != nil {
|
||||
if err := storage.FlushToFile(&allServers); err != nil {
|
||||
return fmt.Errorf("cannot write updated information to file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if maintainerMode {
|
||||
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
|
||||
if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil {
|
||||
return fmt.Errorf("cannot write updated information to file: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -114,7 +109,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
}
|
||||
|
||||
func writeToEmbeddedJSON(repoServersPath string,
|
||||
allServers models.AllServers) error {
|
||||
allServers *models.AllServers) error {
|
||||
const perms = 0600
|
||||
f, err := os.OpenFile(repoServersPath,
|
||||
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
|
||||
|
||||
@@ -27,7 +27,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
|
||||
// Validate Name
|
||||
var validNames []string
|
||||
if vpnType == vpn.OpenVPN {
|
||||
validNames = providers.All()
|
||||
validNames = providers.AllWithCustom()
|
||||
validNames = append(validNames, "pia") // Retro-compatibility
|
||||
} else { // Wireguard
|
||||
validNames = []string{
|
||||
|
||||
@@ -140,118 +140,26 @@ func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection,
|
||||
countryChoices, regionChoices, cityChoices,
|
||||
ispChoices, nameChoices, hostnameChoices []string,
|
||||
err error) {
|
||||
switch vpnServiceProvider {
|
||||
case providers.Custom:
|
||||
case providers.Cyberghost:
|
||||
servers := allServers.GetCyberghost()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Expressvpn:
|
||||
servers := allServers.GetExpressvpn()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Fastestvpn:
|
||||
servers := allServers.GetFastestvpn()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.HideMyAss:
|
||||
servers := allServers.GetHideMyAss()
|
||||
providerServers, ok := allServers.ProviderToServers[vpnServiceProvider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider))
|
||||
}
|
||||
servers := providerServers.Servers
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Ipvanish:
|
||||
servers := allServers.GetIpvanish()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Ivpn:
|
||||
servers := allServers.GetIvpn()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
ispChoices = validation.ExtractISPs(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Mullvad:
|
||||
servers := allServers.GetMullvad()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
ispChoices = validation.ExtractISPs(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Nordvpn:
|
||||
servers := allServers.GetNordvpn()
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Perfectprivacy:
|
||||
servers := allServers.GetPerfectprivacy()
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
case providers.Privado:
|
||||
servers := allServers.GetPrivado()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.PrivateInternetAccess:
|
||||
servers := allServers.GetPia()
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
nameChoices = validation.ExtractServerNames(servers)
|
||||
case providers.Privatevpn:
|
||||
servers := allServers.GetPrivatevpn()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Protonvpn:
|
||||
servers := allServers.GetProtonvpn()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
nameChoices = validation.ExtractServerNames(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Purevpn:
|
||||
servers := allServers.GetPurevpn()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Surfshark:
|
||||
servers := allServers.GetSurfshark()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
|
||||
if vpnServiceProvider == providers.Surfshark {
|
||||
// // Retro compatibility
|
||||
// TODO v4 remove
|
||||
regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...)
|
||||
if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
|
||||
}
|
||||
// Retro compatibility
|
||||
// TODO remove in v4
|
||||
*ss = surfsharkRetroRegion(*ss)
|
||||
case providers.Torguard:
|
||||
servers := allServers.GetTorguard()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.VPNUnlimited:
|
||||
servers := allServers.GetVPNUnlimited()
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Vyprvpn:
|
||||
servers := allServers.GetVyprvpn()
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
case providers.Wevpn:
|
||||
servers := allServers.GetWevpn()
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
case providers.Windscribe:
|
||||
servers := allServers.GetWindscribe()
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrVPNProviderNameNotValid, vpnServiceProvider)
|
||||
}
|
||||
|
||||
return countryChoices, regionChoices, cityChoices,
|
||||
|
||||
@@ -43,10 +43,6 @@ func (u Updater) Validate() (err error) {
|
||||
for i, provider := range u.Providers {
|
||||
valid := false
|
||||
for _, validProvider := range providers.All() {
|
||||
if validProvider == providers.Custom {
|
||||
continue
|
||||
}
|
||||
|
||||
if provider == validProvider {
|
||||
valid = true
|
||||
break
|
||||
|
||||
@@ -19,6 +19,9 @@ func ExtractCountries(servers []models.Server) (values []string) {
|
||||
values = make([]string, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
value := server.Country
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, alreadySeen := seen[value]
|
||||
if alreadySeen {
|
||||
continue
|
||||
@@ -35,6 +38,9 @@ func ExtractRegions(servers []models.Server) (values []string) {
|
||||
values = make([]string, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
value := server.Region
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, alreadySeen := seen[value]
|
||||
if alreadySeen {
|
||||
continue
|
||||
@@ -51,6 +57,9 @@ func ExtractCities(servers []models.Server) (values []string) {
|
||||
values = make([]string, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
value := server.City
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, alreadySeen := seen[value]
|
||||
if alreadySeen {
|
||||
continue
|
||||
@@ -67,6 +76,9 @@ func ExtractISPs(servers []models.Server) (values []string) {
|
||||
values = make([]string, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
value := server.ISP
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, alreadySeen := seen[value]
|
||||
if alreadySeen {
|
||||
continue
|
||||
@@ -83,6 +95,9 @@ func ExtractServerNames(servers []models.Server) (values []string) {
|
||||
values = make([]string, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
value := server.ServerName
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, alreadySeen := seen[value]
|
||||
if alreadySeen {
|
||||
continue
|
||||
@@ -99,6 +114,9 @@ func ExtractHostnames(servers []models.Server) (values []string) {
|
||||
values = make([]string, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
value := server.Hostname
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
_, alreadySeen := seen[value]
|
||||
if alreadySeen {
|
||||
continue
|
||||
|
||||
@@ -26,9 +26,9 @@ const (
|
||||
Windscribe = "windscribe"
|
||||
)
|
||||
|
||||
// All returns all the providers except the custom provider.
|
||||
func All() []string {
|
||||
return []string{
|
||||
Custom,
|
||||
Cyberghost,
|
||||
Expressvpn,
|
||||
Fastestvpn,
|
||||
@@ -51,3 +51,11 @@ func All() []string {
|
||||
Windscribe,
|
||||
}
|
||||
}
|
||||
|
||||
func AllWithCustom() []string {
|
||||
allProviders := All()
|
||||
allProvidersWithCustom := make([]string, len(allProviders)+1)
|
||||
copy(allProvidersWithCustom, allProviders)
|
||||
allProvidersWithCustom[len(allProvidersWithCustom)-1] = Custom
|
||||
return allProvidersWithCustom
|
||||
}
|
||||
|
||||
23
internal/constants/providers/providers_test.go
Normal file
23
internal/constants/providers/providers_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_All(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
all := All()
|
||||
assert.NotContains(t, all, Custom)
|
||||
assert.NotEmpty(t, all)
|
||||
}
|
||||
|
||||
func Test_AllWithCustom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
all := AllWithCustom()
|
||||
assert.Contains(t, all, Custom)
|
||||
assert.Len(t, all, len(All())+1)
|
||||
}
|
||||
@@ -4,108 +4,17 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func (a AllServers) GetCopy() (servers AllServers) {
|
||||
servers = a // copy versions and timestamps
|
||||
servers.Cyberghost.Servers = a.GetCyberghost()
|
||||
servers.Expressvpn.Servers = a.GetExpressvpn()
|
||||
servers.Fastestvpn.Servers = a.GetFastestvpn()
|
||||
servers.HideMyAss.Servers = a.GetHideMyAss()
|
||||
servers.Ipvanish.Servers = a.GetIpvanish()
|
||||
servers.Ivpn.Servers = a.GetIvpn()
|
||||
servers.Mullvad.Servers = a.GetMullvad()
|
||||
servers.Nordvpn.Servers = a.GetNordvpn()
|
||||
servers.Perfectprivacy.Servers = a.GetPerfectprivacy()
|
||||
servers.Privado.Servers = a.GetPrivado()
|
||||
servers.Pia.Servers = a.GetPia()
|
||||
servers.Privatevpn.Servers = a.GetPrivatevpn()
|
||||
servers.Protonvpn.Servers = a.GetProtonvpn()
|
||||
servers.Purevpn.Servers = a.GetPurevpn()
|
||||
servers.Surfshark.Servers = a.GetSurfshark()
|
||||
servers.Torguard.Servers = a.GetTorguard()
|
||||
servers.VPNUnlimited.Servers = a.GetVPNUnlimited()
|
||||
servers.Vyprvpn.Servers = a.GetVyprvpn()
|
||||
servers.Windscribe.Servers = a.GetWindscribe()
|
||||
return servers
|
||||
}
|
||||
|
||||
func (a *AllServers) GetCyberghost() (servers []Server) {
|
||||
return copyServers(a.Cyberghost.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetExpressvpn() (servers []Server) {
|
||||
return copyServers(a.Expressvpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetFastestvpn() (servers []Server) {
|
||||
return copyServers(a.Fastestvpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetHideMyAss() (servers []Server) {
|
||||
return copyServers(a.HideMyAss.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetIpvanish() (servers []Server) {
|
||||
return copyServers(a.Ipvanish.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetIvpn() (servers []Server) {
|
||||
return copyServers(a.Ivpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetMullvad() (servers []Server) {
|
||||
return copyServers(a.Mullvad.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetNordvpn() (servers []Server) {
|
||||
return copyServers(a.Nordvpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetPerfectprivacy() (servers []Server) {
|
||||
return copyServers(a.Perfectprivacy.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetPia() (servers []Server) {
|
||||
return copyServers(a.Pia.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetPrivado() (servers []Server) {
|
||||
return copyServers(a.Privado.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetPrivatevpn() (servers []Server) {
|
||||
return copyServers(a.Privatevpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetProtonvpn() (servers []Server) {
|
||||
return copyServers(a.Protonvpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetPurevpn() (servers []Server) {
|
||||
return copyServers(a.Purevpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetSurfshark() (servers []Server) {
|
||||
return copyServers(a.Surfshark.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetTorguard() (servers []Server) {
|
||||
return copyServers(a.Torguard.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetVPNUnlimited() (servers []Server) {
|
||||
return copyServers(a.VPNUnlimited.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetVyprvpn() (servers []Server) {
|
||||
return copyServers(a.Vyprvpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetWevpn() (servers []Server) {
|
||||
return copyServers(a.Wevpn.Servers)
|
||||
}
|
||||
|
||||
func (a *AllServers) GetWindscribe() (servers []Server) {
|
||||
return copyServers(a.Windscribe.Servers)
|
||||
func (a AllServers) GetCopy() (allServersCopy AllServers) {
|
||||
allServersCopy.Version = a.Version
|
||||
allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
|
||||
for provider, servers := range a.ProviderToServers {
|
||||
allServersCopy.ProviderToServers[provider] = Servers{
|
||||
Version: servers.Version,
|
||||
Timestamp: servers.Timestamp,
|
||||
Servers: copyServers(servers.Servers),
|
||||
}
|
||||
}
|
||||
return allServersCopy
|
||||
}
|
||||
|
||||
func copyServers(servers []Server) (serversCopy []Server) {
|
||||
|
||||
@@ -4,109 +4,118 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_AllServers_GetCopy(t *testing.T) {
|
||||
allServers := AllServers{
|
||||
Cyberghost: Servers{
|
||||
Version: 1,
|
||||
ProviderToServers: map[string]Servers{
|
||||
providers.Cyberghost: {
|
||||
Version: 2,
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Expressvpn: Servers{
|
||||
providers.Expressvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Fastestvpn: Servers{
|
||||
providers.Fastestvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
HideMyAss: Servers{
|
||||
providers.HideMyAss: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Ipvanish: Servers{
|
||||
providers.Ipvanish: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Ivpn: Servers{
|
||||
providers.Ivpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Mullvad: Servers{
|
||||
providers.Mullvad: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Nordvpn: Servers{
|
||||
providers.Nordvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Perfectprivacy: Servers{
|
||||
providers.Perfectprivacy: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Privado: Servers{
|
||||
providers.Privado: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Pia: Servers{
|
||||
providers.PrivateInternetAccess: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Privatevpn: Servers{
|
||||
providers.Privatevpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Protonvpn: Servers{
|
||||
providers.Protonvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Purevpn: Servers{
|
||||
providers.Purevpn: {
|
||||
Version: 1,
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Surfshark: Servers{
|
||||
providers.Surfshark: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Torguard: Servers{
|
||||
providers.Torguard: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
VPNUnlimited: Servers{
|
||||
providers.VPNUnlimited: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Vyprvpn: Servers{
|
||||
providers.Vyprvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
Windscribe: Servers{
|
||||
providers.Wevpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Windscribe: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
servers := allServers.GetCopy()
|
||||
@@ -114,32 +123,6 @@ func Test_AllServers_GetCopy(t *testing.T) {
|
||||
assert.Equal(t, allServers, servers)
|
||||
}
|
||||
|
||||
func Test_AllServers_GetVyprvpn(t *testing.T) {
|
||||
allServers := AllServers{
|
||||
Vyprvpn: Servers{
|
||||
Servers: []Server{
|
||||
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
|
||||
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
servers := allServers.GetVyprvpn()
|
||||
|
||||
expectedServers := []Server{
|
||||
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
|
||||
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
|
||||
}
|
||||
assert.Equal(t, expectedServers, servers)
|
||||
|
||||
allServers.Vyprvpn.Servers[0].IPs[0][0] = 9
|
||||
assert.NotEqual(t, 9, servers[0].IPs[0][0])
|
||||
|
||||
allServers.Vyprvpn.Servers[0].IPs[0][0] = 1
|
||||
servers[0].IPs[0][0] = 9
|
||||
assert.NotEqual(t, 9, allServers.Vyprvpn.Servers[0].IPs[0][0])
|
||||
}
|
||||
|
||||
func Test_copyIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1,54 +1,163 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
)
|
||||
|
||||
type AllServers struct {
|
||||
Version uint16 `json:"version"` // used for migration of the top level scheme
|
||||
Cyberghost Servers `json:"cyberghost"`
|
||||
Expressvpn Servers `json:"expressvpn"`
|
||||
Fastestvpn Servers `json:"fastestvpn"`
|
||||
HideMyAss Servers `json:"hidemyass"`
|
||||
Ipvanish Servers `json:"ipvanish"`
|
||||
Ivpn Servers `json:"ivpn"`
|
||||
Mullvad Servers `json:"mullvad"`
|
||||
Perfectprivacy Servers `json:"perfect privacy"`
|
||||
Nordvpn Servers `json:"nordvpn"`
|
||||
Privado Servers `json:"privado"`
|
||||
Pia Servers `json:"private internet access"`
|
||||
Privatevpn Servers `json:"privatevpn"`
|
||||
Protonvpn Servers `json:"protonvpn"`
|
||||
Purevpn Servers `json:"purevpn"`
|
||||
Surfshark Servers `json:"surfshark"`
|
||||
Torguard Servers `json:"torguard"`
|
||||
VPNUnlimited Servers `json:"vpn unlimited"`
|
||||
Vyprvpn Servers `json:"vyprvpn"`
|
||||
Wevpn Servers `json:"wevpn"`
|
||||
Windscribe Servers `json:"windscribe"`
|
||||
Version uint16 // used for migration of the top level scheme
|
||||
ProviderToServers map[string]Servers
|
||||
}
|
||||
|
||||
func (a *AllServers) Count() int {
|
||||
return len(a.Cyberghost.Servers) +
|
||||
len(a.Expressvpn.Servers) +
|
||||
len(a.Fastestvpn.Servers) +
|
||||
len(a.HideMyAss.Servers) +
|
||||
len(a.Ipvanish.Servers) +
|
||||
len(a.Ivpn.Servers) +
|
||||
len(a.Mullvad.Servers) +
|
||||
len(a.Nordvpn.Servers) +
|
||||
len(a.Perfectprivacy.Servers) +
|
||||
len(a.Privado.Servers) +
|
||||
len(a.Pia.Servers) +
|
||||
len(a.Privatevpn.Servers) +
|
||||
len(a.Protonvpn.Servers) +
|
||||
len(a.Purevpn.Servers) +
|
||||
len(a.Surfshark.Servers) +
|
||||
len(a.Torguard.Servers) +
|
||||
len(a.VPNUnlimited.Servers) +
|
||||
len(a.Vyprvpn.Servers) +
|
||||
len(a.Wevpn.Servers) +
|
||||
len(a.Windscribe.Servers)
|
||||
func (a *AllServers) ServersSlice(provider string) []Server {
|
||||
servers, ok := a.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in all servers", provider))
|
||||
}
|
||||
return copyServers(servers.Servers)
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*AllServers)(nil)
|
||||
|
||||
// MarshalJSON marshals all servers to JSON.
|
||||
// Note you need to use a pointer to all servers
|
||||
// for it to work with native json methods such as
|
||||
// json.Marshal.
|
||||
func (a *AllServers) MarshalJSON() (data []byte, err error) {
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
|
||||
_, err = buffer.WriteString("{")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot write opening bracket: %w", err)
|
||||
}
|
||||
|
||||
versionString := fmt.Sprintf(`"version":%d`, a.Version)
|
||||
_, err = buffer.WriteString(versionString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot write schema version string: %w", err)
|
||||
}
|
||||
|
||||
for _, provider := range providers.All() {
|
||||
servers, ok := a.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in all servers", provider))
|
||||
}
|
||||
|
||||
providerKey := fmt.Sprintf(`,"%s":`, provider)
|
||||
_, err = buffer.WriteString(providerKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot write provider key %s: %w",
|
||||
providerKey, err)
|
||||
}
|
||||
|
||||
serversJSON, err := json.Marshal(servers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed encoding servers for provider %s: %w",
|
||||
provider, err)
|
||||
}
|
||||
_, err = buffer.Write(serversJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot write JSON servers data for provider %s: %w",
|
||||
provider, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = buffer.WriteString("}")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot write closing bracket: %w", err)
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
var _ json.Unmarshaler = (*AllServers)(nil)
|
||||
|
||||
func (a *AllServers) UnmarshalJSON(data []byte) (err error) {
|
||||
keyValues := make(map[string]interface{})
|
||||
err = json.Unmarshal(data, &keyValues)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
versionUnmarshaled := keyValues["version"]
|
||||
if versionUnmarshaled != nil { // defaults to 0
|
||||
version, ok := versionUnmarshaled.(float64)
|
||||
if !ok {
|
||||
return &json.UnmarshalTypeError{
|
||||
Value: fmt.Sprintf("number %v", versionUnmarshaled),
|
||||
Type: reflect.TypeOf(uint16(0)),
|
||||
Struct: "models.AllServers",
|
||||
Field: "Version",
|
||||
}
|
||||
}
|
||||
|
||||
if math.Round(version) != version ||
|
||||
version < 0 || version > float64(^uint16(0)) {
|
||||
return &json.UnmarshalTypeError{
|
||||
Value: fmt.Sprintf("number %v", version),
|
||||
Type: reflect.TypeOf(uint16(0)),
|
||||
Struct: "models.AllServers",
|
||||
Field: "Version",
|
||||
}
|
||||
}
|
||||
|
||||
a.Version = uint16(version)
|
||||
delete(keyValues, "version")
|
||||
}
|
||||
|
||||
if len(keyValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
a.ProviderToServers = make(map[string]Servers, len(keyValues))
|
||||
|
||||
allProviders := providers.All()
|
||||
allProvidersSet := make(map[string]struct{}, len(allProviders))
|
||||
for _, provider := range allProviders {
|
||||
allProvidersSet[provider] = struct{}{}
|
||||
}
|
||||
|
||||
for key, value := range keyValues {
|
||||
if _, ok := allProvidersSet[key]; !ok {
|
||||
// not a provider known by Gluetun
|
||||
// or a non-servers field.
|
||||
continue
|
||||
}
|
||||
|
||||
jsonValue, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot marshal %s servers: %w",
|
||||
key, err)
|
||||
}
|
||||
|
||||
var servers Servers
|
||||
err = json.Unmarshal(jsonValue, &servers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot unmarshal %s servers: %w",
|
||||
key, err)
|
||||
}
|
||||
|
||||
a.ProviderToServers[key] = servers
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllServers) Count() (count int) {
|
||||
for _, servers := range a.ProviderToServers {
|
||||
count += len(servers.Servers)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
type Servers struct {
|
||||
Version uint16 `json:"version"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Servers []Server `json:"servers"`
|
||||
Servers []Server `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
189
internal/models/servers_test.go
Normal file
189
internal/models/servers_test.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_AllServers_MarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
allServers *AllServers
|
||||
dataString string
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
allServers: &AllServers{
|
||||
ProviderToServers: map[string]Servers{},
|
||||
},
|
||||
dataString: `{"version":0,` +
|
||||
`"cyberghost":{"version":0,"timestamp":0},` +
|
||||
`"expressvpn":{"version":0,"timestamp":0},` +
|
||||
`"fastestvpn":{"version":0,"timestamp":0},` +
|
||||
`"hidemyass":{"version":0,"timestamp":0},` +
|
||||
`"ipvanish":{"version":0,"timestamp":0},` +
|
||||
`"ivpn":{"version":0,"timestamp":0},` +
|
||||
`"mullvad":{"version":0,"timestamp":0},` +
|
||||
`"nordvpn":{"version":0,"timestamp":0},` +
|
||||
`"perfect privacy":{"version":0,"timestamp":0},` +
|
||||
`"privado":{"version":0,"timestamp":0},` +
|
||||
`"private internet access":{"version":0,"timestamp":0},` +
|
||||
`"privatevpn":{"version":0,"timestamp":0},` +
|
||||
`"protonvpn":{"version":0,"timestamp":0},` +
|
||||
`"purevpn":{"version":0,"timestamp":0},` +
|
||||
`"surfshark":{"version":0,"timestamp":0},` +
|
||||
`"torguard":{"version":0,"timestamp":0},` +
|
||||
`"vpn unlimited":{"version":0,"timestamp":0},` +
|
||||
`"vyprvpn":{"version":0,"timestamp":0},` +
|
||||
`"wevpn":{"version":0,"timestamp":0},` +
|
||||
`"windscribe":{"version":0,"timestamp":0}}`,
|
||||
},
|
||||
"two known providers": {
|
||||
allServers: &AllServers{
|
||||
Version: 1,
|
||||
ProviderToServers: map[string]Servers{
|
||||
providers.Cyberghost: {
|
||||
Version: 1,
|
||||
Timestamp: 1000,
|
||||
Servers: []Server{
|
||||
{Country: "A"},
|
||||
{Country: "B"},
|
||||
},
|
||||
},
|
||||
providers.Privado: {
|
||||
Version: 2,
|
||||
Timestamp: 2000,
|
||||
Servers: []Server{
|
||||
{City: "C"},
|
||||
{City: "D"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
dataString: `{"version":1,` +
|
||||
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
|
||||
`"expressvpn":{"version":0,"timestamp":0},` +
|
||||
`"fastestvpn":{"version":0,"timestamp":0},` +
|
||||
`"hidemyass":{"version":0,"timestamp":0},` +
|
||||
`"ipvanish":{"version":0,"timestamp":0},` +
|
||||
`"ivpn":{"version":0,"timestamp":0},` +
|
||||
`"mullvad":{"version":0,"timestamp":0},` +
|
||||
`"nordvpn":{"version":0,"timestamp":0},` +
|
||||
`"perfect privacy":{"version":0,"timestamp":0},` +
|
||||
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]},` +
|
||||
`"private internet access":{"version":0,"timestamp":0},` +
|
||||
`"privatevpn":{"version":0,"timestamp":0},` +
|
||||
`"protonvpn":{"version":0,"timestamp":0},` +
|
||||
`"purevpn":{"version":0,"timestamp":0},` +
|
||||
`"surfshark":{"version":0,"timestamp":0},` +
|
||||
`"torguard":{"version":0,"timestamp":0},` +
|
||||
`"vpn unlimited":{"version":0,"timestamp":0},` +
|
||||
`"vyprvpn":{"version":0,"timestamp":0},` +
|
||||
`"wevpn":{"version":0,"timestamp":0},` +
|
||||
`"windscribe":{"version":0,"timestamp":0}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Populate all providers in all servers
|
||||
for _, provider := range providers.All() {
|
||||
_, has := testCase.allServers.ProviderToServers[provider]
|
||||
if !has {
|
||||
testCase.allServers.ProviderToServers[provider] = Servers{}
|
||||
}
|
||||
}
|
||||
|
||||
data, err := testCase.allServers.MarshalJSON()
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
assert.Equal(t, testCase.dataString, string(data))
|
||||
|
||||
data, err = json.Marshal(testCase.allServers)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
assert.Equal(t, testCase.dataString, string(data))
|
||||
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
encoder := json.NewEncoder(buffer)
|
||||
// encoder.SetIndent("", " ")
|
||||
err = encoder.Encode(testCase.allServers)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testCase.dataString+"\n", buffer.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AllServers_UnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
dataString string
|
||||
allServers AllServers
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
dataString: "{}",
|
||||
allServers: AllServers{},
|
||||
},
|
||||
"two known providers": {
|
||||
dataString: `{"version":1,` +
|
||||
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
|
||||
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]}}`,
|
||||
allServers: AllServers{
|
||||
Version: 1,
|
||||
ProviderToServers: map[string]Servers{
|
||||
providers.Cyberghost: {
|
||||
Version: 1,
|
||||
Timestamp: 1000,
|
||||
Servers: []Server{
|
||||
{Country: "A"},
|
||||
{Country: "B"},
|
||||
},
|
||||
},
|
||||
providers.Privado: {
|
||||
Version: 2,
|
||||
Timestamp: 2000,
|
||||
Servers: []Server{
|
||||
{City: "C"},
|
||||
{City: "D"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := []byte(testCase.dataString)
|
||||
var allServers AllServers
|
||||
|
||||
err := json.Unmarshal(data, &allServers)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if err != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
assert.Equal(t, testCase.allServers, allServers)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -51,50 +51,51 @@ type PortForwarder interface {
|
||||
}
|
||||
|
||||
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {
|
||||
serversSlice := allServers.ServersSlice(provider)
|
||||
randSource := rand.NewSource(timeNow().UnixNano())
|
||||
switch provider {
|
||||
case providers.Custom:
|
||||
return custom.New()
|
||||
case providers.Cyberghost:
|
||||
return cyberghost.New(allServers.Cyberghost.Servers, randSource)
|
||||
return cyberghost.New(serversSlice, randSource)
|
||||
case providers.Expressvpn:
|
||||
return expressvpn.New(allServers.Expressvpn.Servers, randSource)
|
||||
return expressvpn.New(serversSlice, randSource)
|
||||
case providers.Fastestvpn:
|
||||
return fastestvpn.New(allServers.Fastestvpn.Servers, randSource)
|
||||
return fastestvpn.New(serversSlice, randSource)
|
||||
case providers.HideMyAss:
|
||||
return hidemyass.New(allServers.HideMyAss.Servers, randSource)
|
||||
return hidemyass.New(serversSlice, randSource)
|
||||
case providers.Ipvanish:
|
||||
return ipvanish.New(allServers.Ipvanish.Servers, randSource)
|
||||
return ipvanish.New(serversSlice, randSource)
|
||||
case providers.Ivpn:
|
||||
return ivpn.New(allServers.Ivpn.Servers, randSource)
|
||||
return ivpn.New(serversSlice, randSource)
|
||||
case providers.Mullvad:
|
||||
return mullvad.New(allServers.Mullvad.Servers, randSource)
|
||||
return mullvad.New(serversSlice, randSource)
|
||||
case providers.Nordvpn:
|
||||
return nordvpn.New(allServers.Nordvpn.Servers, randSource)
|
||||
return nordvpn.New(serversSlice, randSource)
|
||||
case providers.Perfectprivacy:
|
||||
return perfectprivacy.New(allServers.Perfectprivacy.Servers, randSource)
|
||||
return perfectprivacy.New(serversSlice, randSource)
|
||||
case providers.Privado:
|
||||
return privado.New(allServers.Privado.Servers, randSource)
|
||||
return privado.New(serversSlice, randSource)
|
||||
case providers.PrivateInternetAccess:
|
||||
return privateinternetaccess.New(allServers.Pia.Servers, randSource, timeNow)
|
||||
return privateinternetaccess.New(serversSlice, randSource, timeNow)
|
||||
case providers.Privatevpn:
|
||||
return privatevpn.New(allServers.Privatevpn.Servers, randSource)
|
||||
return privatevpn.New(serversSlice, randSource)
|
||||
case providers.Protonvpn:
|
||||
return protonvpn.New(allServers.Protonvpn.Servers, randSource)
|
||||
return protonvpn.New(serversSlice, randSource)
|
||||
case providers.Purevpn:
|
||||
return purevpn.New(allServers.Purevpn.Servers, randSource)
|
||||
return purevpn.New(serversSlice, randSource)
|
||||
case providers.Surfshark:
|
||||
return surfshark.New(allServers.Surfshark.Servers, randSource)
|
||||
return surfshark.New(serversSlice, randSource)
|
||||
case providers.Torguard:
|
||||
return torguard.New(allServers.Torguard.Servers, randSource)
|
||||
return torguard.New(serversSlice, randSource)
|
||||
case providers.VPNUnlimited:
|
||||
return vpnunlimited.New(allServers.VPNUnlimited.Servers, randSource)
|
||||
return vpnunlimited.New(serversSlice, randSource)
|
||||
case providers.Vyprvpn:
|
||||
return vyprvpn.New(allServers.Vyprvpn.Servers, randSource)
|
||||
return vyprvpn.New(serversSlice, randSource)
|
||||
case providers.Wevpn:
|
||||
return wevpn.New(allServers.Wevpn.Servers, randSource)
|
||||
return wevpn.New(serversSlice, randSource)
|
||||
case providers.Windscribe:
|
||||
return windscribe.New(allServers.Windscribe.Servers, randSource)
|
||||
return windscribe.New(serversSlice, randSource)
|
||||
default:
|
||||
panic("provider " + provider + " is unknown") // should never occur
|
||||
}
|
||||
|
||||
@@ -11,14 +11,14 @@ import (
|
||||
var _ Flusher = (*Storage)(nil)
|
||||
|
||||
type Flusher interface {
|
||||
FlushToFile(allServers models.AllServers) error
|
||||
FlushToFile(allServers *models.AllServers) error
|
||||
}
|
||||
|
||||
func (s *Storage) FlushToFile(allServers models.AllServers) error {
|
||||
func (s *Storage) FlushToFile(allServers *models.AllServers) error {
|
||||
return flushToFile(s.filepath, allServers)
|
||||
}
|
||||
|
||||
func flushToFile(path string, servers models.AllServers) error {
|
||||
func flushToFile(path string, servers *models.AllServers) error {
|
||||
dirPath := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dirPath, 0644); err != nil {
|
||||
return err
|
||||
|
||||
@@ -3,6 +3,8 @@ package storage
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -12,5 +14,13 @@ func Test_parseHardcodedServers(t *testing.T) {
|
||||
servers, err := parseHardcodedServers()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, len(servers.Cyberghost.Servers))
|
||||
|
||||
// all providers minus custom
|
||||
allProviders := providers.All()
|
||||
require.Equal(t, len(allProviders), len(servers.ProviderToServers))
|
||||
for _, provider := range allProviders {
|
||||
servers, ok := servers.ProviderToServers[provider]
|
||||
assert.Truef(t, ok, "for provider %s", provider)
|
||||
assert.NotEmptyf(t, servers, "for provider %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,29 +28,20 @@ func (s *Storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int6
|
||||
}
|
||||
|
||||
func (s *Storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
|
||||
return models.AllServers{
|
||||
allProviders := providers.All()
|
||||
merged := models.AllServers{
|
||||
Version: hardcoded.Version,
|
||||
Cyberghost: s.mergeProviderServers(providers.Cyberghost, hardcoded.Cyberghost, persisted.Cyberghost),
|
||||
Expressvpn: s.mergeProviderServers(providers.Expressvpn, hardcoded.Expressvpn, persisted.Expressvpn),
|
||||
Fastestvpn: s.mergeProviderServers(providers.Fastestvpn, hardcoded.Fastestvpn, persisted.Fastestvpn),
|
||||
HideMyAss: s.mergeProviderServers(providers.HideMyAss, hardcoded.HideMyAss, persisted.HideMyAss),
|
||||
Ipvanish: s.mergeProviderServers(providers.Ipvanish, hardcoded.Ipvanish, persisted.Ipvanish),
|
||||
Ivpn: s.mergeProviderServers(providers.Ivpn, hardcoded.Ivpn, persisted.Ivpn),
|
||||
Mullvad: s.mergeProviderServers(providers.Mullvad, hardcoded.Mullvad, persisted.Mullvad),
|
||||
Nordvpn: s.mergeProviderServers(providers.Nordvpn, hardcoded.Nordvpn, persisted.Nordvpn),
|
||||
Perfectprivacy: s.mergeProviderServers(providers.Perfectprivacy, hardcoded.Perfectprivacy, persisted.Perfectprivacy),
|
||||
Privado: s.mergeProviderServers(providers.Privado, hardcoded.Privado, persisted.Privado),
|
||||
Pia: s.mergeProviderServers(providers.PrivateInternetAccess, hardcoded.Pia, persisted.Pia),
|
||||
Privatevpn: s.mergeProviderServers(providers.Privatevpn, hardcoded.Privatevpn, persisted.Privatevpn),
|
||||
Protonvpn: s.mergeProviderServers(providers.Protonvpn, hardcoded.Protonvpn, persisted.Protonvpn),
|
||||
Purevpn: s.mergeProviderServers(providers.Purevpn, hardcoded.Purevpn, persisted.Purevpn),
|
||||
Surfshark: s.mergeProviderServers(providers.Surfshark, hardcoded.Surfshark, persisted.Surfshark),
|
||||
Torguard: s.mergeProviderServers(providers.Torguard, hardcoded.Torguard, persisted.Torguard),
|
||||
VPNUnlimited: s.mergeProviderServers(providers.VPNUnlimited, hardcoded.VPNUnlimited, persisted.VPNUnlimited),
|
||||
Vyprvpn: s.mergeProviderServers(providers.Vyprvpn, hardcoded.Vyprvpn, persisted.Vyprvpn),
|
||||
Wevpn: s.mergeProviderServers(providers.Wevpn, hardcoded.Wevpn, persisted.Wevpn),
|
||||
Windscribe: s.mergeProviderServers(providers.Windscribe, hardcoded.Windscribe, persisted.Windscribe),
|
||||
ProviderToServers: make(map[string]models.Servers, len(allProviders)),
|
||||
}
|
||||
|
||||
for _, provider := range allProviders {
|
||||
hardcodedServers := hardcoded.ProviderToServers[provider]
|
||||
persistedServers := persisted.ProviderToServers[provider]
|
||||
merged.ProviderToServers[provider] = s.mergeProviderServers(provider,
|
||||
hardcodedServers, persistedServers)
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func (s *Storage) mergeProviderServers(provider string,
|
||||
|
||||
@@ -38,172 +38,39 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
|
||||
|
||||
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) (
|
||||
servers models.AllServers, err error) {
|
||||
var versions allVersions
|
||||
if err := json.Unmarshal(b, &versions); err != nil {
|
||||
return servers, fmt.Errorf("cannot decode versions: %w", err)
|
||||
}
|
||||
|
||||
var rawMessages allJSONRawMessages
|
||||
rawMessages := make(map[string]json.RawMessage)
|
||||
if err := json.Unmarshal(b, &rawMessages); err != nil {
|
||||
return servers, fmt.Errorf("cannot decode servers: %w", err)
|
||||
}
|
||||
|
||||
type element struct {
|
||||
provider string
|
||||
hardcoded models.Servers
|
||||
serverVersion serverVersion
|
||||
rawMessage json.RawMessage
|
||||
target *models.Servers
|
||||
}
|
||||
elements := []element{
|
||||
{
|
||||
provider: providers.Cyberghost,
|
||||
hardcoded: hardcoded.Cyberghost,
|
||||
serverVersion: versions.Cyberghost,
|
||||
rawMessage: rawMessages.Cyberghost,
|
||||
target: &servers.Cyberghost,
|
||||
},
|
||||
{
|
||||
provider: providers.Expressvpn,
|
||||
hardcoded: hardcoded.Expressvpn,
|
||||
serverVersion: versions.Expressvpn,
|
||||
rawMessage: rawMessages.Expressvpn,
|
||||
target: &servers.Expressvpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Fastestvpn,
|
||||
hardcoded: hardcoded.Fastestvpn,
|
||||
serverVersion: versions.Fastestvpn,
|
||||
rawMessage: rawMessages.Fastestvpn,
|
||||
target: &servers.Fastestvpn,
|
||||
},
|
||||
{
|
||||
provider: providers.HideMyAss,
|
||||
hardcoded: hardcoded.HideMyAss,
|
||||
serverVersion: versions.HideMyAss,
|
||||
rawMessage: rawMessages.HideMyAss,
|
||||
target: &servers.HideMyAss,
|
||||
},
|
||||
{
|
||||
provider: providers.Ipvanish,
|
||||
hardcoded: hardcoded.Ipvanish,
|
||||
serverVersion: versions.Ipvanish,
|
||||
rawMessage: rawMessages.Ipvanish,
|
||||
target: &servers.Ipvanish,
|
||||
},
|
||||
{
|
||||
provider: providers.Ivpn,
|
||||
hardcoded: hardcoded.Ivpn,
|
||||
serverVersion: versions.Ivpn,
|
||||
rawMessage: rawMessages.Ivpn,
|
||||
target: &servers.Ivpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Mullvad,
|
||||
hardcoded: hardcoded.Mullvad,
|
||||
serverVersion: versions.Mullvad,
|
||||
rawMessage: rawMessages.Mullvad,
|
||||
target: &servers.Mullvad,
|
||||
},
|
||||
{
|
||||
provider: providers.Nordvpn,
|
||||
hardcoded: hardcoded.Nordvpn,
|
||||
serverVersion: versions.Nordvpn,
|
||||
rawMessage: rawMessages.Nordvpn,
|
||||
target: &servers.Nordvpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Perfectprivacy,
|
||||
hardcoded: hardcoded.Perfectprivacy,
|
||||
serverVersion: versions.Perfectprivacy,
|
||||
rawMessage: rawMessages.Perfectprivacy,
|
||||
target: &servers.Perfectprivacy,
|
||||
},
|
||||
{
|
||||
provider: providers.Privado,
|
||||
hardcoded: hardcoded.Privado,
|
||||
serverVersion: versions.Privado,
|
||||
rawMessage: rawMessages.Privado,
|
||||
target: &servers.Privado,
|
||||
},
|
||||
{
|
||||
provider: providers.PrivateInternetAccess,
|
||||
hardcoded: hardcoded.Pia,
|
||||
serverVersion: versions.Pia,
|
||||
rawMessage: rawMessages.Pia,
|
||||
target: &servers.Pia,
|
||||
},
|
||||
{
|
||||
provider: providers.Privatevpn,
|
||||
hardcoded: hardcoded.Privatevpn,
|
||||
serverVersion: versions.Privatevpn,
|
||||
rawMessage: rawMessages.Privatevpn,
|
||||
target: &servers.Privatevpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Protonvpn,
|
||||
hardcoded: hardcoded.Protonvpn,
|
||||
serverVersion: versions.Protonvpn,
|
||||
rawMessage: rawMessages.Protonvpn,
|
||||
target: &servers.Protonvpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Purevpn,
|
||||
hardcoded: hardcoded.Purevpn,
|
||||
serverVersion: versions.Purevpn,
|
||||
rawMessage: rawMessages.Purevpn,
|
||||
target: &servers.Purevpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Surfshark,
|
||||
hardcoded: hardcoded.Surfshark,
|
||||
serverVersion: versions.Surfshark,
|
||||
rawMessage: rawMessages.Surfshark,
|
||||
target: &servers.Surfshark,
|
||||
},
|
||||
{
|
||||
provider: providers.Torguard,
|
||||
hardcoded: hardcoded.Torguard,
|
||||
serverVersion: versions.Torguard,
|
||||
rawMessage: rawMessages.Torguard,
|
||||
target: &servers.Torguard,
|
||||
},
|
||||
{
|
||||
provider: providers.VPNUnlimited,
|
||||
hardcoded: hardcoded.VPNUnlimited,
|
||||
serverVersion: versions.VPNUnlimited,
|
||||
rawMessage: rawMessages.VPNUnlimited,
|
||||
target: &servers.VPNUnlimited,
|
||||
},
|
||||
{
|
||||
provider: providers.Vyprvpn,
|
||||
hardcoded: hardcoded.Vyprvpn,
|
||||
serverVersion: versions.Vyprvpn,
|
||||
rawMessage: rawMessages.Vyprvpn,
|
||||
target: &servers.Vyprvpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Wevpn,
|
||||
hardcoded: hardcoded.Wevpn,
|
||||
serverVersion: versions.Wevpn,
|
||||
rawMessage: rawMessages.Wevpn,
|
||||
target: &servers.Wevpn,
|
||||
},
|
||||
{
|
||||
provider: providers.Windscribe,
|
||||
hardcoded: hardcoded.Windscribe,
|
||||
serverVersion: versions.Windscribe,
|
||||
rawMessage: rawMessages.Windscribe,
|
||||
target: &servers.Windscribe,
|
||||
},
|
||||
// Note schema version is at map key "version" as number
|
||||
|
||||
allProviders := providers.All()
|
||||
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
|
||||
for _, provider := range allProviders {
|
||||
hardcoded, ok := hardcoded.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
|
||||
}
|
||||
|
||||
for _, element := range elements {
|
||||
*element.target, err = s.readServers(element.provider,
|
||||
element.hardcoded, element.serverVersion, element.rawMessage)
|
||||
if err != nil {
|
||||
return servers, err
|
||||
rawMessage, ok := rawMessages[provider]
|
||||
if !ok {
|
||||
// If the provider is not found in the data bytes, just don't set it in
|
||||
// the providers map. That way the hardcoded servers will override them.
|
||||
// This is user provided and could come from different sources in the
|
||||
// future (e.g. a file or API request).
|
||||
continue
|
||||
}
|
||||
|
||||
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage)
|
||||
if err != nil {
|
||||
return models.AllServers{}, err
|
||||
} else if !versionsMatch {
|
||||
// mergedServers is the empty struct in this case, so don't set the key
|
||||
// in the providerToServers map.
|
||||
continue
|
||||
}
|
||||
servers.ProviderToServers[provider] = mergedServers
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
@@ -214,73 +81,20 @@ var (
|
||||
)
|
||||
|
||||
func (s *Storage) readServers(provider string, hardcoded models.Servers,
|
||||
serverVersion serverVersion, rawMessage json.RawMessage) (
|
||||
servers models.Servers, err error) {
|
||||
rawMessage json.RawMessage) (servers models.Servers, versionsMatch bool, err error) {
|
||||
provider = strings.Title(provider)
|
||||
if hardcoded.Version != serverVersion.Version {
|
||||
s.logVersionDiff(provider, hardcoded.Version, serverVersion.Version)
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
err = json.Unmarshal(rawMessage, &servers)
|
||||
var persistedServers models.Servers
|
||||
err = json.Unmarshal(rawMessage, &persistedServers)
|
||||
if err != nil {
|
||||
return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
|
||||
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
versionsMatch = hardcoded.Version == persistedServers.Version
|
||||
if !versionsMatch {
|
||||
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version)
|
||||
return servers, versionsMatch, nil
|
||||
}
|
||||
|
||||
// allVersions is a subset of models.AllServers structure used to track
|
||||
// versions to avoid unmarshaling errors.
|
||||
type allVersions struct {
|
||||
Version uint16 `json:"version"` // used for migration of the top level scheme
|
||||
Cyberghost serverVersion `json:"cyberghost"`
|
||||
Expressvpn serverVersion `json:"expressvpn"`
|
||||
Fastestvpn serverVersion `json:"fastestvpn"`
|
||||
HideMyAss serverVersion `json:"hidemyass"`
|
||||
Ipvanish serverVersion `json:"ipvanish"`
|
||||
Ivpn serverVersion `json:"ivpn"`
|
||||
Mullvad serverVersion `json:"mullvad"`
|
||||
Nordvpn serverVersion `json:"nordvpn"`
|
||||
Perfectprivacy serverVersion `json:"perfect privacy"`
|
||||
Privado serverVersion `json:"privado"`
|
||||
Pia serverVersion `json:"private internet access"`
|
||||
Privatevpn serverVersion `json:"privatevpn"`
|
||||
Protonvpn serverVersion `json:"protonvpn"`
|
||||
Purevpn serverVersion `json:"purevpn"`
|
||||
Surfshark serverVersion `json:"surfshark"`
|
||||
Torguard serverVersion `json:"torguard"`
|
||||
VPNUnlimited serverVersion `json:"vpn unlimited"`
|
||||
Vyprvpn serverVersion `json:"vyprvpn"`
|
||||
Wevpn serverVersion `json:"wevpn"`
|
||||
Windscribe serverVersion `json:"windscribe"`
|
||||
}
|
||||
|
||||
type serverVersion struct {
|
||||
Version uint16 `json:"version"`
|
||||
}
|
||||
|
||||
// allJSONRawMessages is to delay decoding of each provider servers.
|
||||
type allJSONRawMessages struct {
|
||||
Version uint16 `json:"version"` // used for migration of the top level scheme
|
||||
Cyberghost json.RawMessage `json:"cyberghost"`
|
||||
Expressvpn json.RawMessage `json:"expressvpn"`
|
||||
Fastestvpn json.RawMessage `json:"fastestvpn"`
|
||||
HideMyAss json.RawMessage `json:"hidemyass"`
|
||||
Ipvanish json.RawMessage `json:"ipvanish"`
|
||||
Ivpn json.RawMessage `json:"ivpn"`
|
||||
Mullvad json.RawMessage `json:"mullvad"`
|
||||
Nordvpn json.RawMessage `json:"nordvpn"`
|
||||
Perfectprivacy json.RawMessage `json:"perfect privacy"`
|
||||
Privado json.RawMessage `json:"privado"`
|
||||
Pia json.RawMessage `json:"private internet access"`
|
||||
Privatevpn json.RawMessage `json:"privatevpn"`
|
||||
Protonvpn json.RawMessage `json:"protonvpn"`
|
||||
Purevpn json.RawMessage `json:"purevpn"`
|
||||
Surfshark json.RawMessage `json:"surfshark"`
|
||||
Torguard json.RawMessage `json:"torguard"`
|
||||
VPNUnlimited json.RawMessage `json:"vpn unlimited"`
|
||||
Vyprvpn json.RawMessage `json:"vyprvpn"`
|
||||
Wevpn json.RawMessage `json:"wevpn"`
|
||||
Windscribe json.RawMessage `json:"windscribe"`
|
||||
return persistedServers, versionsMatch, nil
|
||||
}
|
||||
|
||||
@@ -1,15 +1,35 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
|
||||
servers models.AllServers) models.AllServers {
|
||||
allProviders := providers.All()
|
||||
if servers.ProviderToServers == nil {
|
||||
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)-1)
|
||||
}
|
||||
for _, provider := range allProviders {
|
||||
_, has := servers.ProviderToServers[provider]
|
||||
if has {
|
||||
continue
|
||||
}
|
||||
servers.ProviderToServers[provider] = models.Servers{
|
||||
Version: allProviderVersion,
|
||||
Timestamp: allProviderTimestamp,
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
func Test_extractServersFromBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -18,60 +38,23 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
hardcoded models.AllServers
|
||||
logged []string
|
||||
persisted models.AllServers
|
||||
err error
|
||||
errMessage string
|
||||
}{
|
||||
"no data": {
|
||||
err: errors.New("cannot decode versions: unexpected end of JSON input"),
|
||||
"bad JSON": {
|
||||
b: []byte("garbage"),
|
||||
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
|
||||
},
|
||||
"empty JSON": {
|
||||
b: []byte("{}"),
|
||||
err: errors.New("cannot decode servers for provider: Cyberghost: unexpected end of JSON input"),
|
||||
"bad provider JSON": {
|
||||
b: []byte(`{"cyberghost": "garbage"}`),
|
||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
||||
errMessage: "cannot decode servers for provider: Cyberghost: " +
|
||||
"json: cannot unmarshal string into Go value of type models.Servers",
|
||||
},
|
||||
"different versions": {
|
||||
"absent provider keys": {
|
||||
b: []byte(`{}`),
|
||||
hardcoded: models.AllServers{
|
||||
Cyberghost: models.Servers{Version: 1},
|
||||
Expressvpn: models.Servers{Version: 1},
|
||||
Fastestvpn: models.Servers{Version: 1},
|
||||
HideMyAss: models.Servers{Version: 1},
|
||||
Ipvanish: models.Servers{Version: 1},
|
||||
Ivpn: models.Servers{Version: 1},
|
||||
Mullvad: models.Servers{Version: 1},
|
||||
Nordvpn: models.Servers{Version: 1},
|
||||
Perfectprivacy: models.Servers{Version: 1},
|
||||
Privado: models.Servers{Version: 1},
|
||||
Pia: models.Servers{Version: 1},
|
||||
Privatevpn: models.Servers{Version: 1},
|
||||
Protonvpn: models.Servers{Version: 1},
|
||||
Purevpn: models.Servers{Version: 1},
|
||||
Surfshark: models.Servers{Version: 1},
|
||||
Torguard: models.Servers{Version: 1},
|
||||
VPNUnlimited: models.Servers{Version: 1},
|
||||
Vyprvpn: models.Servers{Version: 1},
|
||||
Wevpn: models.Servers{Version: 1},
|
||||
Windscribe: models.Servers{Version: 1},
|
||||
},
|
||||
logged: []string{
|
||||
"Cyberghost servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Expressvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Fastestvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Hidemyass servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Ipvanish servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Ivpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Mullvad servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Nordvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Perfect Privacy servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Privado servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Private Internet Access servers from file discarded because they have version 0 and hardcoded servers have version 1", //nolint:lll
|
||||
"Privatevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Protonvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Purevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Surfshark servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Torguard servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Vpn Unlimited servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Vyprvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Wevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
"Windscribe servers from file discarded because they have version 0 and hardcoded servers have version 1",
|
||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
||||
persisted: models.AllServers{
|
||||
ProviderToServers: map[string]models.Servers{},
|
||||
},
|
||||
},
|
||||
"same versions": {
|
||||
@@ -97,49 +80,57 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
"wevpn": {"version": 1, "timestamp": 1},
|
||||
"windscribe": {"version": 1, "timestamp": 1}
|
||||
}`),
|
||||
hardcoded: models.AllServers{
|
||||
Cyberghost: models.Servers{Version: 1},
|
||||
Expressvpn: models.Servers{Version: 1},
|
||||
Fastestvpn: models.Servers{Version: 1},
|
||||
HideMyAss: models.Servers{Version: 1},
|
||||
Ipvanish: models.Servers{Version: 1},
|
||||
Ivpn: models.Servers{Version: 1},
|
||||
Mullvad: models.Servers{Version: 1},
|
||||
Nordvpn: models.Servers{Version: 1},
|
||||
Perfectprivacy: models.Servers{Version: 1},
|
||||
Privado: models.Servers{Version: 1},
|
||||
Pia: models.Servers{Version: 1},
|
||||
Privatevpn: models.Servers{Version: 1},
|
||||
Protonvpn: models.Servers{Version: 1},
|
||||
Purevpn: models.Servers{Version: 1},
|
||||
Surfshark: models.Servers{Version: 1},
|
||||
Torguard: models.Servers{Version: 1},
|
||||
VPNUnlimited: models.Servers{Version: 1},
|
||||
Vyprvpn: models.Servers{Version: 1},
|
||||
Wevpn: models.Servers{Version: 1},
|
||||
Windscribe: models.Servers{Version: 1},
|
||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
||||
persisted: populateProviders(1, 1, models.AllServers{}),
|
||||
},
|
||||
"different versions": {
|
||||
b: []byte(`{
|
||||
"cyberghost": {"version": 1, "timestamp": 1},
|
||||
"expressvpn": {"version": 1, "timestamp": 1},
|
||||
"fastestvpn": {"version": 1, "timestamp": 1},
|
||||
"hidemyass": {"version": 1, "timestamp": 1},
|
||||
"ipvanish": {"version": 1, "timestamp": 1},
|
||||
"ivpn": {"version": 1, "timestamp": 1},
|
||||
"mullvad": {"version": 1, "timestamp": 1},
|
||||
"nordvpn": {"version": 1, "timestamp": 1},
|
||||
"perfect privacy": {"version": 1, "timestamp": 1},
|
||||
"privado": {"version": 1, "timestamp": 1},
|
||||
"private internet access": {"version": 1, "timestamp": 1},
|
||||
"privatevpn": {"version": 1, "timestamp": 1},
|
||||
"protonvpn": {"version": 1, "timestamp": 1},
|
||||
"purevpn": {"version": 1, "timestamp": 1},
|
||||
"surfshark": {"version": 1, "timestamp": 1},
|
||||
"torguard": {"version": 1, "timestamp": 1},
|
||||
"vpn unlimited": {"version": 1, "timestamp": 1},
|
||||
"vyprvpn": {"version": 1, "timestamp": 1},
|
||||
"wevpn": {"version": 1, "timestamp": 1},
|
||||
"windscribe": {"version": 1, "timestamp": 1}
|
||||
}`),
|
||||
hardcoded: populateProviders(2, 0, models.AllServers{}),
|
||||
logged: []string{
|
||||
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Fastestvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Hidemyass servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Ipvanish servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Ivpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Mullvad servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Nordvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Perfect Privacy servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Privado servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Private Internet Access servers from file discarded because they have version 1 and hardcoded servers have version 2", //nolint:lll
|
||||
"Privatevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Protonvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Purevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Surfshark servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Torguard servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Vpn Unlimited servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Vyprvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Wevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Windscribe servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
},
|
||||
persisted: models.AllServers{
|
||||
Cyberghost: models.Servers{Version: 1, Timestamp: 1},
|
||||
Expressvpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Fastestvpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
HideMyAss: models.Servers{Version: 1, Timestamp: 1},
|
||||
Ipvanish: models.Servers{Version: 1, Timestamp: 1},
|
||||
Ivpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Mullvad: models.Servers{Version: 1, Timestamp: 1},
|
||||
Nordvpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Perfectprivacy: models.Servers{Version: 1, Timestamp: 1},
|
||||
Privado: models.Servers{Version: 1, Timestamp: 1},
|
||||
Pia: models.Servers{Version: 1, Timestamp: 1},
|
||||
Privatevpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Protonvpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Purevpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Surfshark: models.Servers{Version: 1, Timestamp: 1},
|
||||
Torguard: models.Servers{Version: 1, Timestamp: 1},
|
||||
VPNUnlimited: models.Servers{Version: 1, Timestamp: 1},
|
||||
Vyprvpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Wevpn: models.Servers{Version: 1, Timestamp: 1},
|
||||
Windscribe: models.Servers{Version: 1, Timestamp: 1},
|
||||
ProviderToServers: map[string]models.Servers{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -151,8 +142,13 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
logger := NewMockInfoErrorer(ctrl)
|
||||
var previousLogCall *gomock.Call
|
||||
for _, logged := range testCase.logged {
|
||||
logger.EXPECT().Info(logged)
|
||||
call := logger.EXPECT().Info(logged)
|
||||
if previousLogCall != nil {
|
||||
call.After(previousLogCall)
|
||||
}
|
||||
previousLogCall = call
|
||||
}
|
||||
|
||||
s := &Storage{
|
||||
@@ -161,9 +157,8 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
|
||||
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -171,4 +166,25 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
assert.Equal(t, testCase.persisted, servers)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("hardcoded panic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := &Storage{}
|
||||
|
||||
allProviders := providers.All()
|
||||
require.GreaterOrEqual(t, len(allProviders), 2)
|
||||
|
||||
b := []byte(`{}`)
|
||||
hardcoded := models.AllServers{
|
||||
ProviderToServers: map[string]models.Servers{
|
||||
allProviders[0]: {},
|
||||
// Missing provider allProviders[1]
|
||||
},
|
||||
}
|
||||
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
|
||||
assert.PanicsWithValue(t, expectedPanicValue, func() {
|
||||
_, _ = s.extractServersFromBytes(b, hardcoded)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,27 +7,11 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func countServers(allServers models.AllServers) int {
|
||||
return len(allServers.Cyberghost.Servers) +
|
||||
len(allServers.Expressvpn.Servers) +
|
||||
len(allServers.Fastestvpn.Servers) +
|
||||
len(allServers.HideMyAss.Servers) +
|
||||
len(allServers.Ipvanish.Servers) +
|
||||
len(allServers.Ivpn.Servers) +
|
||||
len(allServers.Mullvad.Servers) +
|
||||
len(allServers.Nordvpn.Servers) +
|
||||
len(allServers.Perfectprivacy.Servers) +
|
||||
len(allServers.Privado.Servers) +
|
||||
len(allServers.Pia.Servers) +
|
||||
len(allServers.Privatevpn.Servers) +
|
||||
len(allServers.Protonvpn.Servers) +
|
||||
len(allServers.Purevpn.Servers) +
|
||||
len(allServers.Surfshark.Servers) +
|
||||
len(allServers.Torguard.Servers) +
|
||||
len(allServers.VPNUnlimited.Servers) +
|
||||
len(allServers.Vyprvpn.Servers) +
|
||||
len(allServers.Wevpn.Servers) +
|
||||
len(allServers.Windscribe.Servers)
|
||||
func countServers(allServers models.AllServers) (count int) {
|
||||
for _, servers := range allServers.ProviderToServers {
|
||||
count += len(servers.Servers)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (s *Storage) SyncServers() (err error) {
|
||||
@@ -57,7 +41,7 @@ func (s *Storage) SyncServers() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := flushToFile(s.filepath, s.mergedServers); err != nil {
|
||||
if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
|
||||
return fmt.Errorf("cannot write servers to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -139,7 +139,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||
l.stopped <- struct{}{}
|
||||
case servers := <-serversCh:
|
||||
l.setAllServers(servers)
|
||||
if err := l.flusher.FlushToFile(servers); err != nil {
|
||||
if err := l.flusher.FlushToFile(&servers); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
runWg.Wait()
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/updater/providers/cyberghost"
|
||||
"github.com/qdm12/gluetun/internal/updater/providers/expressvpn"
|
||||
"github.com/qdm12/gluetun/internal/updater/providers/fastestvpn"
|
||||
@@ -28,419 +30,96 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/updater/providers/windscribe"
|
||||
)
|
||||
|
||||
func (u *updater) updateCyberghost(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Cyberghost.Servers))
|
||||
servers, err := cyberghost.GetServers(ctx, u.presolver, minServers)
|
||||
func (u *updater) updateProvider(ctx context.Context, provider string) (
|
||||
warnings []string, err error) {
|
||||
existingServers := u.getProviderServers(provider)
|
||||
minServers := getMinServers(existingServers)
|
||||
servers, warnings, err := u.getServers(ctx, provider, minServers)
|
||||
if err != nil {
|
||||
return err
|
||||
return warnings, err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Cyberghost.Servers, servers) {
|
||||
return nil
|
||||
if reflect.DeepEqual(existingServers, servers) {
|
||||
return warnings, nil
|
||||
}
|
||||
|
||||
u.servers.Cyberghost.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Cyberghost.Servers = servers
|
||||
return nil
|
||||
u.patchProvider(provider, servers)
|
||||
return warnings, nil
|
||||
}
|
||||
|
||||
func (u *updater) updateExpressvpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Expressvpn.Servers))
|
||||
servers, warnings, err := expressvpn.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("ExpressVPN: " + warning)
|
||||
func (u *updater) getServers(ctx context.Context, provider string,
|
||||
minServers int) (servers []models.Server, warnings []string, err error) {
|
||||
switch provider {
|
||||
case providers.Custom:
|
||||
panic("cannot update custom provider")
|
||||
case providers.Cyberghost:
|
||||
servers, err = cyberghost.GetServers(ctx, u.presolver, minServers)
|
||||
return servers, nil, err
|
||||
case providers.Expressvpn:
|
||||
return expressvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.Fastestvpn:
|
||||
return fastestvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.HideMyAss:
|
||||
return hidemyass.GetServers(ctx, u.client, u.presolver, minServers)
|
||||
case providers.Ipvanish:
|
||||
return ipvanish.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.Ivpn:
|
||||
return ivpn.GetServers(ctx, u.client, u.presolver, minServers)
|
||||
case providers.Mullvad:
|
||||
servers, err = mullvad.GetServers(ctx, u.client, minServers)
|
||||
return servers, nil, err
|
||||
case providers.Nordvpn:
|
||||
return nordvpn.GetServers(ctx, u.client, minServers)
|
||||
case providers.Perfectprivacy:
|
||||
return perfectprivacy.GetServers(ctx, u.unzipper, minServers)
|
||||
case providers.Privado:
|
||||
return privado.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers)
|
||||
case providers.PrivateInternetAccess:
|
||||
servers, err = pia.GetServers(ctx, u.client, minServers)
|
||||
return servers, nil, err
|
||||
case providers.Privatevpn:
|
||||
return privatevpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.Protonvpn:
|
||||
return protonvpn.GetServers(ctx, u.client, minServers)
|
||||
case providers.Purevpn:
|
||||
return purevpn.GetServers(ctx, u.client, u.unzipper, u.presolver, minServers)
|
||||
case providers.Surfshark:
|
||||
return surfshark.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers)
|
||||
case providers.Torguard:
|
||||
return torguard.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.VPNUnlimited:
|
||||
return vpnunlimited.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.Vyprvpn:
|
||||
return vyprvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
|
||||
case providers.Wevpn:
|
||||
return wevpn.GetServers(ctx, u.presolver, minServers)
|
||||
case providers.Windscribe:
|
||||
servers, err = windscribe.GetServers(ctx, u.client, minServers)
|
||||
return servers, nil, err
|
||||
default:
|
||||
panic("provider " + provider + " is unknown")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Expressvpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Expressvpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Expressvpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateFastestvpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Fastestvpn.Servers))
|
||||
servers, warnings, err := fastestvpn.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("FastestVPN: " + warning)
|
||||
func (u *updater) getProviderServers(provider string) (servers []models.Server) {
|
||||
providerServers, ok := u.servers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s is unknown", provider))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Fastestvpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Fastestvpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Fastestvpn.Servers = servers
|
||||
return nil
|
||||
return providerServers.Servers
|
||||
}
|
||||
|
||||
func (u *updater) updateHideMyAss(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.HideMyAss.Servers))
|
||||
servers, warnings, err := hidemyass.GetServers(
|
||||
ctx, u.client, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("HideMyAss: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.HideMyAss.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.HideMyAss.Timestamp = u.timeNow().Unix()
|
||||
u.servers.HideMyAss.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateIpvanish(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Ipvanish.Servers))
|
||||
servers, warnings, err := ipvanish.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("Ipvanish: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Ipvanish.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Ipvanish.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Ipvanish.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateIvpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Ivpn.Servers))
|
||||
servers, warnings, err := ivpn.GetServers(
|
||||
ctx, u.client, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("Ivpn: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Ivpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Ivpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Ivpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateMullvad(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Mullvad.Servers))
|
||||
servers, err := mullvad.GetServers(ctx, u.client, minServers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Mullvad.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Mullvad.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Mullvad.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateNordvpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Nordvpn.Servers))
|
||||
servers, warnings, err := nordvpn.GetServers(ctx, u.client, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("NordVPN: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Nordvpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Nordvpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Nordvpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updatePerfectprivacy(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Perfectprivacy.Servers))
|
||||
servers, warnings, err := perfectprivacy.GetServers(ctx, u.unzipper, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn(providers.Perfectprivacy + ": " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Perfectprivacy.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Perfectprivacy.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Perfectprivacy.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updatePIA(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Pia.Servers))
|
||||
servers, err := pia.GetServers(ctx, u.client, minServers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Pia.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Pia.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Pia.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updatePrivado(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Privado.Servers))
|
||||
servers, warnings, err := privado.GetServers(
|
||||
ctx, u.unzipper, u.client, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("Privado: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Privado.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Privado.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Privado.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updatePrivatevpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Privatevpn.Servers))
|
||||
servers, warnings, err := privatevpn.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("PrivateVPN: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Privatevpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Privatevpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Privatevpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateProtonvpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Privatevpn.Servers))
|
||||
servers, warnings, err := protonvpn.GetServers(ctx, u.client, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("ProtonVPN: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Protonvpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Protonvpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Protonvpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updatePurevpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Purevpn.Servers))
|
||||
servers, warnings, err := purevpn.GetServers(
|
||||
ctx, u.client, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("PureVPN: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot update Purevpn servers: %w", err)
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Purevpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Purevpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Purevpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateSurfshark(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Surfshark.Servers))
|
||||
servers, warnings, err := surfshark.GetServers(
|
||||
ctx, u.unzipper, u.client, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("Surfshark: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Surfshark.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Surfshark.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Surfshark.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateTorguard(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Torguard.Servers))
|
||||
servers, warnings, err := torguard.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("Torguard: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Torguard.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Torguard.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Torguard.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateVPNUnlimited(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.VPNUnlimited.Servers))
|
||||
servers, warnings, err := vpnunlimited.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn(providers.VPNUnlimited + ": " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.VPNUnlimited.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.VPNUnlimited.Timestamp = u.timeNow().Unix()
|
||||
u.servers.VPNUnlimited.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateVyprvpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Vyprvpn.Servers))
|
||||
servers, warnings, err := vyprvpn.GetServers(
|
||||
ctx, u.unzipper, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("VyprVPN: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Vyprvpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Vyprvpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Vyprvpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateWevpn(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Wevpn.Servers))
|
||||
servers, warnings, err := wevpn.GetServers(ctx, u.presolver, minServers)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn("WeVPN: " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Wevpn.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Wevpn.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Wevpn.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *updater) updateWindscribe(ctx context.Context) (err error) {
|
||||
minServers := getMinServers(len(u.servers.Windscribe.Servers))
|
||||
servers, err := windscribe.GetServers(ctx, u.client, minServers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(u.servers.Windscribe.Servers, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.servers.Windscribe.Timestamp = u.timeNow().Unix()
|
||||
u.servers.Windscribe.Servers = servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMinServers(existingServers int) (minServers int) {
|
||||
func getMinServers(servers []models.Server) (minServers int) {
|
||||
const minRatio = 0.8
|
||||
return int(minRatio * float64(existingServers))
|
||||
return int(minRatio * float64(len(servers)))
|
||||
}
|
||||
|
||||
func (u *updater) patchProvider(provider string, servers []models.Server) {
|
||||
providerServers, ok := u.servers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s is unknown", provider))
|
||||
}
|
||||
providerServers.Timestamp = time.Now().Unix()
|
||||
providerServers.Servers = servers
|
||||
u.servers.ProviderToServers[provider] = providerServers
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/updater/resolver"
|
||||
"github.com/qdm12/gluetun/internal/updater/unzip"
|
||||
@@ -47,16 +46,17 @@ func New(settings settings.Updater, httpClient *http.Client,
|
||||
}
|
||||
}
|
||||
|
||||
type updateFunc func(ctx context.Context) (err error)
|
||||
|
||||
func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) {
|
||||
for _, provider := range u.options.Providers {
|
||||
u.logger.Info("updating " + strings.Title(provider) + " servers...")
|
||||
updateProvider := u.getUpdateFunction(provider)
|
||||
|
||||
// TODO support servers offering only TCP or only UDP
|
||||
// for NordVPN and PureVPN
|
||||
err = updateProvider(ctx)
|
||||
warnings, err := u.updateProvider(ctx, provider)
|
||||
if *u.options.CLI {
|
||||
for _, warning := range warnings {
|
||||
u.logger.Warn(provider + ": " + warning)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return allServers, ctxErr
|
||||
@@ -67,52 +67,3 @@ func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServe
|
||||
|
||||
return u.servers, nil
|
||||
}
|
||||
|
||||
func (u *updater) getUpdateFunction(provider string) (updateFunction updateFunc) {
|
||||
switch provider {
|
||||
case providers.Custom:
|
||||
panic("cannot update custom provider")
|
||||
case providers.Cyberghost:
|
||||
return func(ctx context.Context) (err error) { return u.updateCyberghost(ctx) }
|
||||
case providers.Expressvpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateExpressvpn(ctx) }
|
||||
case providers.Fastestvpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateFastestvpn(ctx) }
|
||||
case providers.HideMyAss:
|
||||
return func(ctx context.Context) (err error) { return u.updateHideMyAss(ctx) }
|
||||
case providers.Ipvanish:
|
||||
return func(ctx context.Context) (err error) { return u.updateIpvanish(ctx) }
|
||||
case providers.Ivpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateIvpn(ctx) }
|
||||
case providers.Mullvad:
|
||||
return func(ctx context.Context) (err error) { return u.updateMullvad(ctx) }
|
||||
case providers.Nordvpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateNordvpn(ctx) }
|
||||
case providers.Perfectprivacy:
|
||||
return func(ctx context.Context) (err error) { return u.updatePerfectprivacy(ctx) }
|
||||
case providers.Privado:
|
||||
return func(ctx context.Context) (err error) { return u.updatePrivado(ctx) }
|
||||
case providers.PrivateInternetAccess:
|
||||
return func(ctx context.Context) (err error) { return u.updatePIA(ctx) }
|
||||
case providers.Privatevpn:
|
||||
return func(ctx context.Context) (err error) { return u.updatePrivatevpn(ctx) }
|
||||
case providers.Protonvpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateProtonvpn(ctx) }
|
||||
case providers.Purevpn:
|
||||
return func(ctx context.Context) (err error) { return u.updatePurevpn(ctx) }
|
||||
case providers.Surfshark:
|
||||
return func(ctx context.Context) (err error) { return u.updateSurfshark(ctx) }
|
||||
case providers.Torguard:
|
||||
return func(ctx context.Context) (err error) { return u.updateTorguard(ctx) }
|
||||
case providers.VPNUnlimited:
|
||||
return func(ctx context.Context) (err error) { return u.updateVPNUnlimited(ctx) }
|
||||
case providers.Vyprvpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateVyprvpn(ctx) }
|
||||
case providers.Wevpn:
|
||||
return func(ctx context.Context) (err error) { return u.updateWevpn(ctx) }
|
||||
case providers.Windscribe:
|
||||
return func(ctx context.Context) (err error) { return u.updateWindscribe(ctx) }
|
||||
default:
|
||||
panic("provider " + provider + " is unknown")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user