chore(all): provider to servers map in allServers

- Simplify formatting CLI
- Simplify updater code
- Simplify filter choices for config validation
- Simplify all servers deep copying
- Custom JSON marshaling methods for `AllServers`
- Simplify provider constructor switch
- Simplify storage merging
- Simplify storage reading and extraction
- Simplify updating code
This commit is contained in:
Quentin McGaw
2022-05-27 00:59:47 +00:00
parent 5ffe8555ba
commit bd0868d764
22 changed files with 854 additions and 1295 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
)
@@ -20,6 +21,7 @@ type ServersFormatter interface {
var (
ErrFormatNotRecognized = errors.New("format is not recognized")
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
ErrMultipleProvidersToFormat = errors.New("more than one VPN provider to format were specified")
)
func addProviderFlag(flagSet *flag.FlagSet,
@@ -31,21 +33,12 @@ func addProviderFlag(flagSet *flag.FlagSet,
flagSet.BoolVar(boolPtr, provider, false, "Format "+strings.Title(provider)+" servers")
}
func getFormatForProvider(providerToFormat map[string]*bool, provider string) (format bool) {
formatPtr, ok := providerToFormat[provider]
if !ok {
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
}
return *formatPtr
}
func (c *CLI) FormatServers(args []string) error {
var format, output string
allProviders := providers.All()
providersToFormat := make(map[string]*bool, len(allProviders))
for _, provider := range allProviders {
value := false
providersToFormat[provider] = &value
providersToFormat[provider] = new(bool)
}
flagSet := flag.NewFlagSet("markdown", flag.ExitOnError)
flagSet.StringVar(&format, "format", "markdown", "Format to use which can be: 'markdown'")
@@ -61,6 +54,24 @@ func (c *CLI) FormatServers(args []string) error {
return fmt.Errorf("%w: %s", ErrFormatNotRecognized, format)
}
// Verify only one provider is set to be formatted.
var providers []string
for provider, formatPtr := range providersToFormat {
if *formatPtr {
providers = append(providers, provider)
}
}
switch len(providers) {
case 0:
return ErrProviderUnspecified
case 1:
default:
return fmt.Errorf("%w: %d specified: %s",
ErrMultipleProvidersToFormat, len(providers),
strings.Join(providers, ", "))
}
providerToFormat := providers[0]
logger := newNoopLogger()
storage, err := storage.New(logger, constants.ServersData)
if err != nil {
@@ -68,51 +79,7 @@ func (c *CLI) FormatServers(args []string) error {
}
currentServers := storage.GetServers()
var formatted string
switch {
case getFormatForProvider(providersToFormat, providers.Cyberghost):
formatted = currentServers.Cyberghost.ToMarkdown(providers.Cyberghost)
case getFormatForProvider(providersToFormat, providers.Expressvpn):
formatted = currentServers.Expressvpn.ToMarkdown(providers.Expressvpn)
case getFormatForProvider(providersToFormat, providers.Fastestvpn):
formatted = currentServers.Fastestvpn.ToMarkdown(providers.Fastestvpn)
case getFormatForProvider(providersToFormat, providers.HideMyAss):
formatted = currentServers.HideMyAss.ToMarkdown(providers.HideMyAss)
case getFormatForProvider(providersToFormat, providers.Ipvanish):
formatted = currentServers.Ipvanish.ToMarkdown(providers.Ipvanish)
case getFormatForProvider(providersToFormat, providers.Ivpn):
formatted = currentServers.Ivpn.ToMarkdown(providers.Ivpn)
case getFormatForProvider(providersToFormat, providers.Mullvad):
formatted = currentServers.Mullvad.ToMarkdown(providers.Mullvad)
case getFormatForProvider(providersToFormat, providers.Nordvpn):
formatted = currentServers.Nordvpn.ToMarkdown(providers.Nordvpn)
case getFormatForProvider(providersToFormat, providers.Perfectprivacy):
formatted = currentServers.Perfectprivacy.ToMarkdown(providers.Perfectprivacy)
case getFormatForProvider(providersToFormat, providers.PrivateInternetAccess):
formatted = currentServers.Pia.ToMarkdown(providers.PrivateInternetAccess)
case getFormatForProvider(providersToFormat, providers.Privado):
formatted = currentServers.Privado.ToMarkdown(providers.Privado)
case getFormatForProvider(providersToFormat, providers.Privatevpn):
formatted = currentServers.Privatevpn.ToMarkdown(providers.Privatevpn)
case getFormatForProvider(providersToFormat, providers.Protonvpn):
formatted = currentServers.Protonvpn.ToMarkdown(providers.Protonvpn)
case getFormatForProvider(providersToFormat, providers.Purevpn):
formatted = currentServers.Purevpn.ToMarkdown(providers.Purevpn)
case getFormatForProvider(providersToFormat, providers.Surfshark):
formatted = currentServers.Surfshark.ToMarkdown(providers.Surfshark)
case getFormatForProvider(providersToFormat, providers.Torguard):
formatted = currentServers.Torguard.ToMarkdown(providers.Torguard)
case getFormatForProvider(providersToFormat, providers.VPNUnlimited):
formatted = currentServers.VPNUnlimited.ToMarkdown(providers.VPNUnlimited)
case getFormatForProvider(providersToFormat, providers.Vyprvpn):
formatted = currentServers.Vyprvpn.ToMarkdown(providers.Vyprvpn)
case getFormatForProvider(providersToFormat, providers.Wevpn):
formatted = currentServers.Wevpn.ToMarkdown(providers.Wevpn)
case getFormatForProvider(providersToFormat, providers.Windscribe):
formatted = currentServers.Windscribe.ToMarkdown(providers.Windscribe)
default:
return ErrProviderUnspecified
}
formatted := formatServers(currentServers, providerToFormat)
output = filepath.Clean(output)
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
@@ -133,3 +100,11 @@ func (c *CLI) FormatServers(args []string) error {
return nil
}
func formatServers(allServers models.AllServers, provider string) (formatted string) {
servers, ok := allServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
}
return servers.ToMarkdown(providers.Cyberghost)
}

View File

@@ -63,12 +63,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
}
if updateAll {
for _, provider := range providers.All() {
if provider == providers.Custom {
continue
}
options.Providers = append(options.Providers, provider)
}
options.Providers = providers.All()
} else {
if csvProviders == "" {
return ErrNoProviderSpecified
@@ -99,13 +94,13 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
}
if endUserMode {
if err := storage.FlushToFile(allServers); err != nil {
if err := storage.FlushToFile(&allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err)
}
}
if maintainerMode {
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err)
}
}
@@ -114,7 +109,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
}
func writeToEmbeddedJSON(repoServersPath string,
allServers models.AllServers) error {
allServers *models.AllServers) error {
const perms = 0600
f, err := os.OpenFile(repoServersPath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)

View File

@@ -27,7 +27,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
// Validate Name
var validNames []string
if vpnType == vpn.OpenVPN {
validNames = providers.All()
validNames = providers.AllWithCustom()
validNames = append(validNames, "pia") // Retro-compatibility
} else { // Wireguard
validNames = []string{

View File

@@ -140,118 +140,26 @@ func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection,
countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices []string,
err error) {
switch vpnServiceProvider {
case providers.Custom:
case providers.Cyberghost:
servers := allServers.GetCyberghost()
countryChoices = validation.ExtractCountries(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Expressvpn:
servers := allServers.GetExpressvpn()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Fastestvpn:
servers := allServers.GetFastestvpn()
countryChoices = validation.ExtractCountries(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.HideMyAss:
servers := allServers.GetHideMyAss()
providerServers, ok := allServers.ProviderToServers[vpnServiceProvider]
if !ok {
panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider))
}
servers := providerServers.Servers
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Ipvanish:
servers := allServers.GetIpvanish()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Ivpn:
servers := allServers.GetIvpn()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
ispChoices = validation.ExtractISPs(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Mullvad:
servers := allServers.GetMullvad()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
ispChoices = validation.ExtractISPs(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Nordvpn:
servers := allServers.GetNordvpn()
regionChoices = validation.ExtractRegions(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Perfectprivacy:
servers := allServers.GetPerfectprivacy()
cityChoices = validation.ExtractCities(servers)
case providers.Privado:
servers := allServers.GetPrivado()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.PrivateInternetAccess:
servers := allServers.GetPia()
regionChoices = validation.ExtractRegions(servers)
hostnameChoices = validation.ExtractHostnames(servers)
nameChoices = validation.ExtractServerNames(servers)
case providers.Privatevpn:
servers := allServers.GetPrivatevpn()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Protonvpn:
servers := allServers.GetProtonvpn()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
nameChoices = validation.ExtractServerNames(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Purevpn:
servers := allServers.GetPurevpn()
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Surfshark:
servers := allServers.GetSurfshark()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
regionChoices = validation.ExtractRegions(servers)
if vpnServiceProvider == providers.Surfshark {
// // Retro compatibility
// TODO v4 remove
regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...)
if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
}
// Retro compatibility
// TODO remove in v4
*ss = surfsharkRetroRegion(*ss)
case providers.Torguard:
servers := allServers.GetTorguard()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.VPNUnlimited:
servers := allServers.GetVPNUnlimited()
countryChoices = validation.ExtractCountries(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Vyprvpn:
servers := allServers.GetVyprvpn()
regionChoices = validation.ExtractRegions(servers)
case providers.Wevpn:
servers := allServers.GetWevpn()
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
case providers.Windscribe:
servers := allServers.GetWindscribe()
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
hostnameChoices = validation.ExtractHostnames(servers)
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrVPNProviderNameNotValid, vpnServiceProvider)
}
return countryChoices, regionChoices, cityChoices,

View File

@@ -43,10 +43,6 @@ func (u Updater) Validate() (err error) {
for i, provider := range u.Providers {
valid := false
for _, validProvider := range providers.All() {
if validProvider == providers.Custom {
continue
}
if provider == validProvider {
valid = true
break

View File

@@ -19,6 +19,9 @@ func ExtractCountries(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers))
for _, server := range servers {
value := server.Country
if value == "" {
continue
}
_, alreadySeen := seen[value]
if alreadySeen {
continue
@@ -35,6 +38,9 @@ func ExtractRegions(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers))
for _, server := range servers {
value := server.Region
if value == "" {
continue
}
_, alreadySeen := seen[value]
if alreadySeen {
continue
@@ -51,6 +57,9 @@ func ExtractCities(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers))
for _, server := range servers {
value := server.City
if value == "" {
continue
}
_, alreadySeen := seen[value]
if alreadySeen {
continue
@@ -67,6 +76,9 @@ func ExtractISPs(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers))
for _, server := range servers {
value := server.ISP
if value == "" {
continue
}
_, alreadySeen := seen[value]
if alreadySeen {
continue
@@ -83,6 +95,9 @@ func ExtractServerNames(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers))
for _, server := range servers {
value := server.ServerName
if value == "" {
continue
}
_, alreadySeen := seen[value]
if alreadySeen {
continue
@@ -99,6 +114,9 @@ func ExtractHostnames(servers []models.Server) (values []string) {
values = make([]string, 0, len(servers))
for _, server := range servers {
value := server.Hostname
if value == "" {
continue
}
_, alreadySeen := seen[value]
if alreadySeen {
continue

View File

@@ -26,9 +26,9 @@ const (
Windscribe = "windscribe"
)
// All returns all the providers except the custom provider.
func All() []string {
return []string{
Custom,
Cyberghost,
Expressvpn,
Fastestvpn,
@@ -51,3 +51,11 @@ func All() []string {
Windscribe,
}
}
func AllWithCustom() []string {
allProviders := All()
allProvidersWithCustom := make([]string, len(allProviders)+1)
copy(allProvidersWithCustom, allProviders)
allProvidersWithCustom[len(allProvidersWithCustom)-1] = Custom
return allProvidersWithCustom
}

View File

@@ -0,0 +1,23 @@
package providers
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_All(t *testing.T) {
t.Parallel()
all := All()
assert.NotContains(t, all, Custom)
assert.NotEmpty(t, all)
}
func Test_AllWithCustom(t *testing.T) {
t.Parallel()
all := AllWithCustom()
assert.Contains(t, all, Custom)
assert.Len(t, all, len(All())+1)
}

View File

@@ -4,108 +4,17 @@ import (
"net"
)
func (a AllServers) GetCopy() (servers AllServers) {
servers = a // copy versions and timestamps
servers.Cyberghost.Servers = a.GetCyberghost()
servers.Expressvpn.Servers = a.GetExpressvpn()
servers.Fastestvpn.Servers = a.GetFastestvpn()
servers.HideMyAss.Servers = a.GetHideMyAss()
servers.Ipvanish.Servers = a.GetIpvanish()
servers.Ivpn.Servers = a.GetIvpn()
servers.Mullvad.Servers = a.GetMullvad()
servers.Nordvpn.Servers = a.GetNordvpn()
servers.Perfectprivacy.Servers = a.GetPerfectprivacy()
servers.Privado.Servers = a.GetPrivado()
servers.Pia.Servers = a.GetPia()
servers.Privatevpn.Servers = a.GetPrivatevpn()
servers.Protonvpn.Servers = a.GetProtonvpn()
servers.Purevpn.Servers = a.GetPurevpn()
servers.Surfshark.Servers = a.GetSurfshark()
servers.Torguard.Servers = a.GetTorguard()
servers.VPNUnlimited.Servers = a.GetVPNUnlimited()
servers.Vyprvpn.Servers = a.GetVyprvpn()
servers.Windscribe.Servers = a.GetWindscribe()
return servers
}
func (a *AllServers) GetCyberghost() (servers []Server) {
return copyServers(a.Cyberghost.Servers)
}
func (a *AllServers) GetExpressvpn() (servers []Server) {
return copyServers(a.Expressvpn.Servers)
}
func (a *AllServers) GetFastestvpn() (servers []Server) {
return copyServers(a.Fastestvpn.Servers)
}
func (a *AllServers) GetHideMyAss() (servers []Server) {
return copyServers(a.HideMyAss.Servers)
}
func (a *AllServers) GetIpvanish() (servers []Server) {
return copyServers(a.Ipvanish.Servers)
}
func (a *AllServers) GetIvpn() (servers []Server) {
return copyServers(a.Ivpn.Servers)
}
func (a *AllServers) GetMullvad() (servers []Server) {
return copyServers(a.Mullvad.Servers)
}
func (a *AllServers) GetNordvpn() (servers []Server) {
return copyServers(a.Nordvpn.Servers)
}
func (a *AllServers) GetPerfectprivacy() (servers []Server) {
return copyServers(a.Perfectprivacy.Servers)
}
func (a *AllServers) GetPia() (servers []Server) {
return copyServers(a.Pia.Servers)
}
func (a *AllServers) GetPrivado() (servers []Server) {
return copyServers(a.Privado.Servers)
}
func (a *AllServers) GetPrivatevpn() (servers []Server) {
return copyServers(a.Privatevpn.Servers)
}
func (a *AllServers) GetProtonvpn() (servers []Server) {
return copyServers(a.Protonvpn.Servers)
}
func (a *AllServers) GetPurevpn() (servers []Server) {
return copyServers(a.Purevpn.Servers)
}
func (a *AllServers) GetSurfshark() (servers []Server) {
return copyServers(a.Surfshark.Servers)
}
func (a *AllServers) GetTorguard() (servers []Server) {
return copyServers(a.Torguard.Servers)
}
func (a *AllServers) GetVPNUnlimited() (servers []Server) {
return copyServers(a.VPNUnlimited.Servers)
}
func (a *AllServers) GetVyprvpn() (servers []Server) {
return copyServers(a.Vyprvpn.Servers)
}
func (a *AllServers) GetWevpn() (servers []Server) {
return copyServers(a.Wevpn.Servers)
}
func (a *AllServers) GetWindscribe() (servers []Server) {
return copyServers(a.Windscribe.Servers)
func (a AllServers) GetCopy() (allServersCopy AllServers) {
allServersCopy.Version = a.Version
allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
for provider, servers := range a.ProviderToServers {
allServersCopy.ProviderToServers[provider] = Servers{
Version: servers.Version,
Timestamp: servers.Timestamp,
Servers: copyServers(servers.Servers),
}
}
return allServersCopy
}
func copyServers(servers []Server) (serversCopy []Server) {

View File

@@ -4,109 +4,118 @@ import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_GetCopy(t *testing.T) {
allServers := AllServers{
Cyberghost: Servers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 2,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Expressvpn: Servers{
providers.Expressvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Fastestvpn: Servers{
providers.Fastestvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
HideMyAss: Servers{
providers.HideMyAss: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Ipvanish: Servers{
providers.Ipvanish: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Ivpn: Servers{
providers.Ivpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Mullvad: Servers{
providers.Mullvad: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Nordvpn: Servers{
providers.Nordvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Perfectprivacy: Servers{
providers.Perfectprivacy: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Privado: Servers{
providers.Privado: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Pia: Servers{
providers.PrivateInternetAccess: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Privatevpn: Servers{
providers.Privatevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Protonvpn: Servers{
providers.Protonvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Purevpn: Servers{
providers.Purevpn: {
Version: 1,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Surfshark: Servers{
providers.Surfshark: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Torguard: Servers{
providers.Torguard: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
VPNUnlimited: Servers{
providers.VPNUnlimited: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Vyprvpn: Servers{
providers.Vyprvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Windscribe: Servers{
providers.Wevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Windscribe: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
},
}
servers := allServers.GetCopy()
@@ -114,32 +123,6 @@ func Test_AllServers_GetCopy(t *testing.T) {
assert.Equal(t, allServers, servers)
}
func Test_AllServers_GetVyprvpn(t *testing.T) {
allServers := AllServers{
Vyprvpn: Servers{
Servers: []Server{
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
},
},
}
servers := allServers.GetVyprvpn()
expectedServers := []Server{
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
}
assert.Equal(t, expectedServers, servers)
allServers.Vyprvpn.Servers[0].IPs[0][0] = 9
assert.NotEqual(t, 9, servers[0].IPs[0][0])
allServers.Vyprvpn.Servers[0].IPs[0][0] = 1
servers[0].IPs[0][0] = 9
assert.NotEqual(t, 9, allServers.Vyprvpn.Servers[0].IPs[0][0])
}
func Test_copyIPs(t *testing.T) {
t.Parallel()

View File

@@ -1,54 +1,163 @@
package models
import (
"bytes"
"encoding/json"
"fmt"
"math"
"reflect"
"github.com/qdm12/gluetun/internal/constants/providers"
)
type AllServers struct {
Version uint16 `json:"version"` // used for migration of the top level scheme
Cyberghost Servers `json:"cyberghost"`
Expressvpn Servers `json:"expressvpn"`
Fastestvpn Servers `json:"fastestvpn"`
HideMyAss Servers `json:"hidemyass"`
Ipvanish Servers `json:"ipvanish"`
Ivpn Servers `json:"ivpn"`
Mullvad Servers `json:"mullvad"`
Perfectprivacy Servers `json:"perfect privacy"`
Nordvpn Servers `json:"nordvpn"`
Privado Servers `json:"privado"`
Pia Servers `json:"private internet access"`
Privatevpn Servers `json:"privatevpn"`
Protonvpn Servers `json:"protonvpn"`
Purevpn Servers `json:"purevpn"`
Surfshark Servers `json:"surfshark"`
Torguard Servers `json:"torguard"`
VPNUnlimited Servers `json:"vpn unlimited"`
Vyprvpn Servers `json:"vyprvpn"`
Wevpn Servers `json:"wevpn"`
Windscribe Servers `json:"windscribe"`
Version uint16 // used for migration of the top level scheme
ProviderToServers map[string]Servers
}
func (a *AllServers) Count() int {
return len(a.Cyberghost.Servers) +
len(a.Expressvpn.Servers) +
len(a.Fastestvpn.Servers) +
len(a.HideMyAss.Servers) +
len(a.Ipvanish.Servers) +
len(a.Ivpn.Servers) +
len(a.Mullvad.Servers) +
len(a.Nordvpn.Servers) +
len(a.Perfectprivacy.Servers) +
len(a.Privado.Servers) +
len(a.Pia.Servers) +
len(a.Privatevpn.Servers) +
len(a.Protonvpn.Servers) +
len(a.Purevpn.Servers) +
len(a.Surfshark.Servers) +
len(a.Torguard.Servers) +
len(a.VPNUnlimited.Servers) +
len(a.Vyprvpn.Servers) +
len(a.Wevpn.Servers) +
len(a.Windscribe.Servers)
func (a *AllServers) ServersSlice(provider string) []Server {
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
return copyServers(servers.Servers)
}
var _ json.Marshaler = (*AllServers)(nil)
// MarshalJSON marshals all servers to JSON.
// Note you need to use a pointer to all servers
// for it to work with native json methods such as
// json.Marshal.
func (a *AllServers) MarshalJSON() (data []byte, err error) {
buffer := bytes.NewBuffer(nil)
_, err = buffer.WriteString("{")
if err != nil {
return nil, fmt.Errorf("cannot write opening bracket: %w", err)
}
versionString := fmt.Sprintf(`"version":%d`, a.Version)
_, err = buffer.WriteString(versionString)
if err != nil {
return nil, fmt.Errorf("cannot write schema version string: %w", err)
}
for _, provider := range providers.All() {
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
providerKey := fmt.Sprintf(`,"%s":`, provider)
_, err = buffer.WriteString(providerKey)
if err != nil {
return nil, fmt.Errorf("cannot write provider key %s: %w",
providerKey, err)
}
serversJSON, err := json.Marshal(servers)
if err != nil {
return nil, fmt.Errorf("failed encoding servers for provider %s: %w",
provider, err)
}
_, err = buffer.Write(serversJSON)
if err != nil {
return nil, fmt.Errorf("cannot write JSON servers data for provider %s: %w",
provider, err)
}
}
_, err = buffer.WriteString("}")
if err != nil {
return nil, fmt.Errorf("cannot write closing bracket: %w", err)
}
return buffer.Bytes(), nil
}
var _ json.Unmarshaler = (*AllServers)(nil)
func (a *AllServers) UnmarshalJSON(data []byte) (err error) {
keyValues := make(map[string]interface{})
err = json.Unmarshal(data, &keyValues)
if err != nil {
return err
}
versionUnmarshaled := keyValues["version"]
if versionUnmarshaled != nil { // defaults to 0
version, ok := versionUnmarshaled.(float64)
if !ok {
return &json.UnmarshalTypeError{
Value: fmt.Sprintf("number %v", versionUnmarshaled),
Type: reflect.TypeOf(uint16(0)),
Struct: "models.AllServers",
Field: "Version",
}
}
if math.Round(version) != version ||
version < 0 || version > float64(^uint16(0)) {
return &json.UnmarshalTypeError{
Value: fmt.Sprintf("number %v", version),
Type: reflect.TypeOf(uint16(0)),
Struct: "models.AllServers",
Field: "Version",
}
}
a.Version = uint16(version)
delete(keyValues, "version")
}
if len(keyValues) == 0 {
return nil
}
a.ProviderToServers = make(map[string]Servers, len(keyValues))
allProviders := providers.All()
allProvidersSet := make(map[string]struct{}, len(allProviders))
for _, provider := range allProviders {
allProvidersSet[provider] = struct{}{}
}
for key, value := range keyValues {
if _, ok := allProvidersSet[key]; !ok {
// not a provider known by Gluetun
// or a non-servers field.
continue
}
jsonValue, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("cannot marshal %s servers: %w",
key, err)
}
var servers Servers
err = json.Unmarshal(jsonValue, &servers)
if err != nil {
return fmt.Errorf("cannot unmarshal %s servers: %w",
key, err)
}
a.ProviderToServers[key] = servers
}
return nil
}
func (a *AllServers) Count() (count int) {
for _, servers := range a.ProviderToServers {
count += len(servers.Servers)
}
return count
}
type Servers struct {
Version uint16 `json:"version"`
Timestamp int64 `json:"timestamp"`
Servers []Server `json:"servers"`
Servers []Server `json:"servers,omitempty"`
}

View File

@@ -0,0 +1,189 @@
package models
import (
"bytes"
"encoding/json"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_MarshalJSON(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
allServers *AllServers
dataString string
errWrapped error
errMessage string
}{
"empty": {
allServers: &AllServers{
ProviderToServers: map[string]Servers{},
},
dataString: `{"version":0,` +
`"cyberghost":{"version":0,"timestamp":0},` +
`"expressvpn":{"version":0,"timestamp":0},` +
`"fastestvpn":{"version":0,"timestamp":0},` +
`"hidemyass":{"version":0,"timestamp":0},` +
`"ipvanish":{"version":0,"timestamp":0},` +
`"ivpn":{"version":0,"timestamp":0},` +
`"mullvad":{"version":0,"timestamp":0},` +
`"nordvpn":{"version":0,"timestamp":0},` +
`"perfect privacy":{"version":0,"timestamp":0},` +
`"privado":{"version":0,"timestamp":0},` +
`"private internet access":{"version":0,"timestamp":0},` +
`"privatevpn":{"version":0,"timestamp":0},` +
`"protonvpn":{"version":0,"timestamp":0},` +
`"purevpn":{"version":0,"timestamp":0},` +
`"surfshark":{"version":0,"timestamp":0},` +
`"torguard":{"version":0,"timestamp":0},` +
`"vpn unlimited":{"version":0,"timestamp":0},` +
`"vyprvpn":{"version":0,"timestamp":0},` +
`"wevpn":{"version":0,"timestamp":0},` +
`"windscribe":{"version":0,"timestamp":0}}`,
},
"two known providers": {
allServers: &AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 1,
Timestamp: 1000,
Servers: []Server{
{Country: "A"},
{Country: "B"},
},
},
providers.Privado: {
Version: 2,
Timestamp: 2000,
Servers: []Server{
{City: "C"},
{City: "D"},
},
},
},
},
dataString: `{"version":1,` +
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
`"expressvpn":{"version":0,"timestamp":0},` +
`"fastestvpn":{"version":0,"timestamp":0},` +
`"hidemyass":{"version":0,"timestamp":0},` +
`"ipvanish":{"version":0,"timestamp":0},` +
`"ivpn":{"version":0,"timestamp":0},` +
`"mullvad":{"version":0,"timestamp":0},` +
`"nordvpn":{"version":0,"timestamp":0},` +
`"perfect privacy":{"version":0,"timestamp":0},` +
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]},` +
`"private internet access":{"version":0,"timestamp":0},` +
`"privatevpn":{"version":0,"timestamp":0},` +
`"protonvpn":{"version":0,"timestamp":0},` +
`"purevpn":{"version":0,"timestamp":0},` +
`"surfshark":{"version":0,"timestamp":0},` +
`"torguard":{"version":0,"timestamp":0},` +
`"vpn unlimited":{"version":0,"timestamp":0},` +
`"vyprvpn":{"version":0,"timestamp":0},` +
`"wevpn":{"version":0,"timestamp":0},` +
`"windscribe":{"version":0,"timestamp":0}}`,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Populate all providers in all servers
for _, provider := range providers.All() {
_, has := testCase.allServers.ProviderToServers[provider]
if !has {
testCase.allServers.ProviderToServers[provider] = Servers{}
}
}
data, err := testCase.allServers.MarshalJSON()
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.dataString, string(data))
data, err = json.Marshal(testCase.allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.dataString, string(data))
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
// encoder.SetIndent("", " ")
err = encoder.Encode(testCase.allServers)
require.NoError(t, err)
assert.Equal(t, testCase.dataString+"\n", buffer.String())
})
}
}
func Test_AllServers_UnmarshalJSON(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
dataString string
allServers AllServers
errWrapped error
errMessage string
}{
"empty": {
dataString: "{}",
allServers: AllServers{},
},
"two known providers": {
dataString: `{"version":1,` +
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]}}`,
allServers: AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 1,
Timestamp: 1000,
Servers: []Server{
{Country: "A"},
{Country: "B"},
},
},
providers.Privado: {
Version: 2,
Timestamp: 2000,
Servers: []Server{
{City: "C"},
{City: "D"},
},
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
data := []byte(testCase.dataString)
var allServers AllServers
err := json.Unmarshal(data, &allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.allServers, allServers)
})
}
}

View File

@@ -51,50 +51,51 @@ type PortForwarder interface {
}
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {
serversSlice := allServers.ServersSlice(provider)
randSource := rand.NewSource(timeNow().UnixNano())
switch provider {
case providers.Custom:
return custom.New()
case providers.Cyberghost:
return cyberghost.New(allServers.Cyberghost.Servers, randSource)
return cyberghost.New(serversSlice, randSource)
case providers.Expressvpn:
return expressvpn.New(allServers.Expressvpn.Servers, randSource)
return expressvpn.New(serversSlice, randSource)
case providers.Fastestvpn:
return fastestvpn.New(allServers.Fastestvpn.Servers, randSource)
return fastestvpn.New(serversSlice, randSource)
case providers.HideMyAss:
return hidemyass.New(allServers.HideMyAss.Servers, randSource)
return hidemyass.New(serversSlice, randSource)
case providers.Ipvanish:
return ipvanish.New(allServers.Ipvanish.Servers, randSource)
return ipvanish.New(serversSlice, randSource)
case providers.Ivpn:
return ivpn.New(allServers.Ivpn.Servers, randSource)
return ivpn.New(serversSlice, randSource)
case providers.Mullvad:
return mullvad.New(allServers.Mullvad.Servers, randSource)
return mullvad.New(serversSlice, randSource)
case providers.Nordvpn:
return nordvpn.New(allServers.Nordvpn.Servers, randSource)
return nordvpn.New(serversSlice, randSource)
case providers.Perfectprivacy:
return perfectprivacy.New(allServers.Perfectprivacy.Servers, randSource)
return perfectprivacy.New(serversSlice, randSource)
case providers.Privado:
return privado.New(allServers.Privado.Servers, randSource)
return privado.New(serversSlice, randSource)
case providers.PrivateInternetAccess:
return privateinternetaccess.New(allServers.Pia.Servers, randSource, timeNow)
return privateinternetaccess.New(serversSlice, randSource, timeNow)
case providers.Privatevpn:
return privatevpn.New(allServers.Privatevpn.Servers, randSource)
return privatevpn.New(serversSlice, randSource)
case providers.Protonvpn:
return protonvpn.New(allServers.Protonvpn.Servers, randSource)
return protonvpn.New(serversSlice, randSource)
case providers.Purevpn:
return purevpn.New(allServers.Purevpn.Servers, randSource)
return purevpn.New(serversSlice, randSource)
case providers.Surfshark:
return surfshark.New(allServers.Surfshark.Servers, randSource)
return surfshark.New(serversSlice, randSource)
case providers.Torguard:
return torguard.New(allServers.Torguard.Servers, randSource)
return torguard.New(serversSlice, randSource)
case providers.VPNUnlimited:
return vpnunlimited.New(allServers.VPNUnlimited.Servers, randSource)
return vpnunlimited.New(serversSlice, randSource)
case providers.Vyprvpn:
return vyprvpn.New(allServers.Vyprvpn.Servers, randSource)
return vyprvpn.New(serversSlice, randSource)
case providers.Wevpn:
return wevpn.New(allServers.Wevpn.Servers, randSource)
return wevpn.New(serversSlice, randSource)
case providers.Windscribe:
return windscribe.New(allServers.Windscribe.Servers, randSource)
return windscribe.New(serversSlice, randSource)
default:
panic("provider " + provider + " is unknown") // should never occur
}

View File

@@ -11,14 +11,14 @@ import (
var _ Flusher = (*Storage)(nil)
type Flusher interface {
FlushToFile(allServers models.AllServers) error
FlushToFile(allServers *models.AllServers) error
}
func (s *Storage) FlushToFile(allServers models.AllServers) error {
func (s *Storage) FlushToFile(allServers *models.AllServers) error {
return flushToFile(s.filepath, allServers)
}
func flushToFile(path string, servers models.AllServers) error {
func flushToFile(path string, servers *models.AllServers) error {
dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil {
return err

View File

@@ -3,6 +3,8 @@ package storage
import (
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -12,5 +14,13 @@ func Test_parseHardcodedServers(t *testing.T) {
servers, err := parseHardcodedServers()
require.NoError(t, err)
require.NotEmpty(t, len(servers.Cyberghost.Servers))
// all providers minus custom
allProviders := providers.All()
require.Equal(t, len(allProviders), len(servers.ProviderToServers))
for _, provider := range allProviders {
servers, ok := servers.ProviderToServers[provider]
assert.Truef(t, ok, "for provider %s", provider)
assert.NotEmptyf(t, servers, "for provider %s", provider)
}
}

View File

@@ -28,29 +28,20 @@ func (s *Storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int6
}
func (s *Storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
return models.AllServers{
allProviders := providers.All()
merged := models.AllServers{
Version: hardcoded.Version,
Cyberghost: s.mergeProviderServers(providers.Cyberghost, hardcoded.Cyberghost, persisted.Cyberghost),
Expressvpn: s.mergeProviderServers(providers.Expressvpn, hardcoded.Expressvpn, persisted.Expressvpn),
Fastestvpn: s.mergeProviderServers(providers.Fastestvpn, hardcoded.Fastestvpn, persisted.Fastestvpn),
HideMyAss: s.mergeProviderServers(providers.HideMyAss, hardcoded.HideMyAss, persisted.HideMyAss),
Ipvanish: s.mergeProviderServers(providers.Ipvanish, hardcoded.Ipvanish, persisted.Ipvanish),
Ivpn: s.mergeProviderServers(providers.Ivpn, hardcoded.Ivpn, persisted.Ivpn),
Mullvad: s.mergeProviderServers(providers.Mullvad, hardcoded.Mullvad, persisted.Mullvad),
Nordvpn: s.mergeProviderServers(providers.Nordvpn, hardcoded.Nordvpn, persisted.Nordvpn),
Perfectprivacy: s.mergeProviderServers(providers.Perfectprivacy, hardcoded.Perfectprivacy, persisted.Perfectprivacy),
Privado: s.mergeProviderServers(providers.Privado, hardcoded.Privado, persisted.Privado),
Pia: s.mergeProviderServers(providers.PrivateInternetAccess, hardcoded.Pia, persisted.Pia),
Privatevpn: s.mergeProviderServers(providers.Privatevpn, hardcoded.Privatevpn, persisted.Privatevpn),
Protonvpn: s.mergeProviderServers(providers.Protonvpn, hardcoded.Protonvpn, persisted.Protonvpn),
Purevpn: s.mergeProviderServers(providers.Purevpn, hardcoded.Purevpn, persisted.Purevpn),
Surfshark: s.mergeProviderServers(providers.Surfshark, hardcoded.Surfshark, persisted.Surfshark),
Torguard: s.mergeProviderServers(providers.Torguard, hardcoded.Torguard, persisted.Torguard),
VPNUnlimited: s.mergeProviderServers(providers.VPNUnlimited, hardcoded.VPNUnlimited, persisted.VPNUnlimited),
Vyprvpn: s.mergeProviderServers(providers.Vyprvpn, hardcoded.Vyprvpn, persisted.Vyprvpn),
Wevpn: s.mergeProviderServers(providers.Wevpn, hardcoded.Wevpn, persisted.Wevpn),
Windscribe: s.mergeProviderServers(providers.Windscribe, hardcoded.Windscribe, persisted.Windscribe),
ProviderToServers: make(map[string]models.Servers, len(allProviders)),
}
for _, provider := range allProviders {
hardcodedServers := hardcoded.ProviderToServers[provider]
persistedServers := persisted.ProviderToServers[provider]
merged.ProviderToServers[provider] = s.mergeProviderServers(provider,
hardcodedServers, persistedServers)
}
return merged
}
func (s *Storage) mergeProviderServers(provider string,

View File

@@ -38,172 +38,39 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) (
servers models.AllServers, err error) {
var versions allVersions
if err := json.Unmarshal(b, &versions); err != nil {
return servers, fmt.Errorf("cannot decode versions: %w", err)
}
var rawMessages allJSONRawMessages
rawMessages := make(map[string]json.RawMessage)
if err := json.Unmarshal(b, &rawMessages); err != nil {
return servers, fmt.Errorf("cannot decode servers: %w", err)
}
type element struct {
provider string
hardcoded models.Servers
serverVersion serverVersion
rawMessage json.RawMessage
target *models.Servers
}
elements := []element{
{
provider: providers.Cyberghost,
hardcoded: hardcoded.Cyberghost,
serverVersion: versions.Cyberghost,
rawMessage: rawMessages.Cyberghost,
target: &servers.Cyberghost,
},
{
provider: providers.Expressvpn,
hardcoded: hardcoded.Expressvpn,
serverVersion: versions.Expressvpn,
rawMessage: rawMessages.Expressvpn,
target: &servers.Expressvpn,
},
{
provider: providers.Fastestvpn,
hardcoded: hardcoded.Fastestvpn,
serverVersion: versions.Fastestvpn,
rawMessage: rawMessages.Fastestvpn,
target: &servers.Fastestvpn,
},
{
provider: providers.HideMyAss,
hardcoded: hardcoded.HideMyAss,
serverVersion: versions.HideMyAss,
rawMessage: rawMessages.HideMyAss,
target: &servers.HideMyAss,
},
{
provider: providers.Ipvanish,
hardcoded: hardcoded.Ipvanish,
serverVersion: versions.Ipvanish,
rawMessage: rawMessages.Ipvanish,
target: &servers.Ipvanish,
},
{
provider: providers.Ivpn,
hardcoded: hardcoded.Ivpn,
serverVersion: versions.Ivpn,
rawMessage: rawMessages.Ivpn,
target: &servers.Ivpn,
},
{
provider: providers.Mullvad,
hardcoded: hardcoded.Mullvad,
serverVersion: versions.Mullvad,
rawMessage: rawMessages.Mullvad,
target: &servers.Mullvad,
},
{
provider: providers.Nordvpn,
hardcoded: hardcoded.Nordvpn,
serverVersion: versions.Nordvpn,
rawMessage: rawMessages.Nordvpn,
target: &servers.Nordvpn,
},
{
provider: providers.Perfectprivacy,
hardcoded: hardcoded.Perfectprivacy,
serverVersion: versions.Perfectprivacy,
rawMessage: rawMessages.Perfectprivacy,
target: &servers.Perfectprivacy,
},
{
provider: providers.Privado,
hardcoded: hardcoded.Privado,
serverVersion: versions.Privado,
rawMessage: rawMessages.Privado,
target: &servers.Privado,
},
{
provider: providers.PrivateInternetAccess,
hardcoded: hardcoded.Pia,
serverVersion: versions.Pia,
rawMessage: rawMessages.Pia,
target: &servers.Pia,
},
{
provider: providers.Privatevpn,
hardcoded: hardcoded.Privatevpn,
serverVersion: versions.Privatevpn,
rawMessage: rawMessages.Privatevpn,
target: &servers.Privatevpn,
},
{
provider: providers.Protonvpn,
hardcoded: hardcoded.Protonvpn,
serverVersion: versions.Protonvpn,
rawMessage: rawMessages.Protonvpn,
target: &servers.Protonvpn,
},
{
provider: providers.Purevpn,
hardcoded: hardcoded.Purevpn,
serverVersion: versions.Purevpn,
rawMessage: rawMessages.Purevpn,
target: &servers.Purevpn,
},
{
provider: providers.Surfshark,
hardcoded: hardcoded.Surfshark,
serverVersion: versions.Surfshark,
rawMessage: rawMessages.Surfshark,
target: &servers.Surfshark,
},
{
provider: providers.Torguard,
hardcoded: hardcoded.Torguard,
serverVersion: versions.Torguard,
rawMessage: rawMessages.Torguard,
target: &servers.Torguard,
},
{
provider: providers.VPNUnlimited,
hardcoded: hardcoded.VPNUnlimited,
serverVersion: versions.VPNUnlimited,
rawMessage: rawMessages.VPNUnlimited,
target: &servers.VPNUnlimited,
},
{
provider: providers.Vyprvpn,
hardcoded: hardcoded.Vyprvpn,
serverVersion: versions.Vyprvpn,
rawMessage: rawMessages.Vyprvpn,
target: &servers.Vyprvpn,
},
{
provider: providers.Wevpn,
hardcoded: hardcoded.Wevpn,
serverVersion: versions.Wevpn,
rawMessage: rawMessages.Wevpn,
target: &servers.Wevpn,
},
{
provider: providers.Windscribe,
hardcoded: hardcoded.Windscribe,
serverVersion: versions.Windscribe,
rawMessage: rawMessages.Windscribe,
target: &servers.Windscribe,
},
// Note schema version is at map key "version" as number
allProviders := providers.All()
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
for _, provider := range allProviders {
hardcoded, ok := hardcoded.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
}
for _, element := range elements {
*element.target, err = s.readServers(element.provider,
element.hardcoded, element.serverVersion, element.rawMessage)
if err != nil {
return servers, err
rawMessage, ok := rawMessages[provider]
if !ok {
// If the provider is not found in the data bytes, just don't set it in
// the providers map. That way the hardcoded servers will override them.
// This is user provided and could come from different sources in the
// future (e.g. a file or API request).
continue
}
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage)
if err != nil {
return models.AllServers{}, err
} else if !versionsMatch {
// mergedServers is the empty struct in this case, so don't set the key
// in the providerToServers map.
continue
}
servers.ProviderToServers[provider] = mergedServers
}
return servers, nil
@@ -214,73 +81,20 @@ var (
)
func (s *Storage) readServers(provider string, hardcoded models.Servers,
serverVersion serverVersion, rawMessage json.RawMessage) (
servers models.Servers, err error) {
rawMessage json.RawMessage) (servers models.Servers, versionsMatch bool, err error) {
provider = strings.Title(provider)
if hardcoded.Version != serverVersion.Version {
s.logVersionDiff(provider, hardcoded.Version, serverVersion.Version)
return servers, nil
}
err = json.Unmarshal(rawMessage, &servers)
var persistedServers models.Servers
err = json.Unmarshal(rawMessage, &persistedServers)
if err != nil {
return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
}
return servers, nil
}
versionsMatch = hardcoded.Version == persistedServers.Version
if !versionsMatch {
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version)
return servers, versionsMatch, nil
}
// allVersions is a subset of models.AllServers structure used to track
// versions to avoid unmarshaling errors.
type allVersions struct {
Version uint16 `json:"version"` // used for migration of the top level scheme
Cyberghost serverVersion `json:"cyberghost"`
Expressvpn serverVersion `json:"expressvpn"`
Fastestvpn serverVersion `json:"fastestvpn"`
HideMyAss serverVersion `json:"hidemyass"`
Ipvanish serverVersion `json:"ipvanish"`
Ivpn serverVersion `json:"ivpn"`
Mullvad serverVersion `json:"mullvad"`
Nordvpn serverVersion `json:"nordvpn"`
Perfectprivacy serverVersion `json:"perfect privacy"`
Privado serverVersion `json:"privado"`
Pia serverVersion `json:"private internet access"`
Privatevpn serverVersion `json:"privatevpn"`
Protonvpn serverVersion `json:"protonvpn"`
Purevpn serverVersion `json:"purevpn"`
Surfshark serverVersion `json:"surfshark"`
Torguard serverVersion `json:"torguard"`
VPNUnlimited serverVersion `json:"vpn unlimited"`
Vyprvpn serverVersion `json:"vyprvpn"`
Wevpn serverVersion `json:"wevpn"`
Windscribe serverVersion `json:"windscribe"`
}
type serverVersion struct {
Version uint16 `json:"version"`
}
// allJSONRawMessages is to delay decoding of each provider servers.
type allJSONRawMessages struct {
Version uint16 `json:"version"` // used for migration of the top level scheme
Cyberghost json.RawMessage `json:"cyberghost"`
Expressvpn json.RawMessage `json:"expressvpn"`
Fastestvpn json.RawMessage `json:"fastestvpn"`
HideMyAss json.RawMessage `json:"hidemyass"`
Ipvanish json.RawMessage `json:"ipvanish"`
Ivpn json.RawMessage `json:"ivpn"`
Mullvad json.RawMessage `json:"mullvad"`
Nordvpn json.RawMessage `json:"nordvpn"`
Perfectprivacy json.RawMessage `json:"perfect privacy"`
Privado json.RawMessage `json:"privado"`
Pia json.RawMessage `json:"private internet access"`
Privatevpn json.RawMessage `json:"privatevpn"`
Protonvpn json.RawMessage `json:"protonvpn"`
Purevpn json.RawMessage `json:"purevpn"`
Surfshark json.RawMessage `json:"surfshark"`
Torguard json.RawMessage `json:"torguard"`
VPNUnlimited json.RawMessage `json:"vpn unlimited"`
Vyprvpn json.RawMessage `json:"vyprvpn"`
Wevpn json.RawMessage `json:"wevpn"`
Windscribe json.RawMessage `json:"windscribe"`
return persistedServers, versionsMatch, nil
}

View File

@@ -1,15 +1,35 @@
package storage
import (
"errors"
"fmt"
"testing"
gomock "github.com/golang/mock/gomock"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
servers models.AllServers) models.AllServers {
allProviders := providers.All()
if servers.ProviderToServers == nil {
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)-1)
}
for _, provider := range allProviders {
_, has := servers.ProviderToServers[provider]
if has {
continue
}
servers.ProviderToServers[provider] = models.Servers{
Version: allProviderVersion,
Timestamp: allProviderTimestamp,
}
}
return servers
}
func Test_extractServersFromBytes(t *testing.T) {
t.Parallel()
@@ -18,60 +38,23 @@ func Test_extractServersFromBytes(t *testing.T) {
hardcoded models.AllServers
logged []string
persisted models.AllServers
err error
errMessage string
}{
"no data": {
err: errors.New("cannot decode versions: unexpected end of JSON input"),
"bad JSON": {
b: []byte("garbage"),
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
},
"empty JSON": {
b: []byte("{}"),
err: errors.New("cannot decode servers for provider: Cyberghost: unexpected end of JSON input"),
"bad provider JSON": {
b: []byte(`{"cyberghost": "garbage"}`),
hardcoded: populateProviders(1, 0, models.AllServers{}),
errMessage: "cannot decode servers for provider: Cyberghost: " +
"json: cannot unmarshal string into Go value of type models.Servers",
},
"different versions": {
"absent provider keys": {
b: []byte(`{}`),
hardcoded: models.AllServers{
Cyberghost: models.Servers{Version: 1},
Expressvpn: models.Servers{Version: 1},
Fastestvpn: models.Servers{Version: 1},
HideMyAss: models.Servers{Version: 1},
Ipvanish: models.Servers{Version: 1},
Ivpn: models.Servers{Version: 1},
Mullvad: models.Servers{Version: 1},
Nordvpn: models.Servers{Version: 1},
Perfectprivacy: models.Servers{Version: 1},
Privado: models.Servers{Version: 1},
Pia: models.Servers{Version: 1},
Privatevpn: models.Servers{Version: 1},
Protonvpn: models.Servers{Version: 1},
Purevpn: models.Servers{Version: 1},
Surfshark: models.Servers{Version: 1},
Torguard: models.Servers{Version: 1},
VPNUnlimited: models.Servers{Version: 1},
Vyprvpn: models.Servers{Version: 1},
Wevpn: models.Servers{Version: 1},
Windscribe: models.Servers{Version: 1},
},
logged: []string{
"Cyberghost servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Expressvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Fastestvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Hidemyass servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Ipvanish servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Ivpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Mullvad servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Nordvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Perfect Privacy servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Privado servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Private Internet Access servers from file discarded because they have version 0 and hardcoded servers have version 1", //nolint:lll
"Privatevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Protonvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Purevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Surfshark servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Torguard servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Vpn Unlimited servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Vyprvpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Wevpn servers from file discarded because they have version 0 and hardcoded servers have version 1",
"Windscribe servers from file discarded because they have version 0 and hardcoded servers have version 1",
hardcoded: populateProviders(1, 0, models.AllServers{}),
persisted: models.AllServers{
ProviderToServers: map[string]models.Servers{},
},
},
"same versions": {
@@ -97,49 +80,57 @@ func Test_extractServersFromBytes(t *testing.T) {
"wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1}
}`),
hardcoded: models.AllServers{
Cyberghost: models.Servers{Version: 1},
Expressvpn: models.Servers{Version: 1},
Fastestvpn: models.Servers{Version: 1},
HideMyAss: models.Servers{Version: 1},
Ipvanish: models.Servers{Version: 1},
Ivpn: models.Servers{Version: 1},
Mullvad: models.Servers{Version: 1},
Nordvpn: models.Servers{Version: 1},
Perfectprivacy: models.Servers{Version: 1},
Privado: models.Servers{Version: 1},
Pia: models.Servers{Version: 1},
Privatevpn: models.Servers{Version: 1},
Protonvpn: models.Servers{Version: 1},
Purevpn: models.Servers{Version: 1},
Surfshark: models.Servers{Version: 1},
Torguard: models.Servers{Version: 1},
VPNUnlimited: models.Servers{Version: 1},
Vyprvpn: models.Servers{Version: 1},
Wevpn: models.Servers{Version: 1},
Windscribe: models.Servers{Version: 1},
hardcoded: populateProviders(1, 0, models.AllServers{}),
persisted: populateProviders(1, 1, models.AllServers{}),
},
"different versions": {
b: []byte(`{
"cyberghost": {"version": 1, "timestamp": 1},
"expressvpn": {"version": 1, "timestamp": 1},
"fastestvpn": {"version": 1, "timestamp": 1},
"hidemyass": {"version": 1, "timestamp": 1},
"ipvanish": {"version": 1, "timestamp": 1},
"ivpn": {"version": 1, "timestamp": 1},
"mullvad": {"version": 1, "timestamp": 1},
"nordvpn": {"version": 1, "timestamp": 1},
"perfect privacy": {"version": 1, "timestamp": 1},
"privado": {"version": 1, "timestamp": 1},
"private internet access": {"version": 1, "timestamp": 1},
"privatevpn": {"version": 1, "timestamp": 1},
"protonvpn": {"version": 1, "timestamp": 1},
"purevpn": {"version": 1, "timestamp": 1},
"surfshark": {"version": 1, "timestamp": 1},
"torguard": {"version": 1, "timestamp": 1},
"vpn unlimited": {"version": 1, "timestamp": 1},
"vyprvpn": {"version": 1, "timestamp": 1},
"wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1}
}`),
hardcoded: populateProviders(2, 0, models.AllServers{}),
logged: []string{
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Fastestvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Hidemyass servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Ipvanish servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Ivpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Mullvad servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Nordvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Perfect Privacy servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Privado servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Private Internet Access servers from file discarded because they have version 1 and hardcoded servers have version 2", //nolint:lll
"Privatevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Protonvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Purevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Surfshark servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Torguard servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Vpn Unlimited servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Vyprvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Wevpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Windscribe servers from file discarded because they have version 1 and hardcoded servers have version 2",
},
persisted: models.AllServers{
Cyberghost: models.Servers{Version: 1, Timestamp: 1},
Expressvpn: models.Servers{Version: 1, Timestamp: 1},
Fastestvpn: models.Servers{Version: 1, Timestamp: 1},
HideMyAss: models.Servers{Version: 1, Timestamp: 1},
Ipvanish: models.Servers{Version: 1, Timestamp: 1},
Ivpn: models.Servers{Version: 1, Timestamp: 1},
Mullvad: models.Servers{Version: 1, Timestamp: 1},
Nordvpn: models.Servers{Version: 1, Timestamp: 1},
Perfectprivacy: models.Servers{Version: 1, Timestamp: 1},
Privado: models.Servers{Version: 1, Timestamp: 1},
Pia: models.Servers{Version: 1, Timestamp: 1},
Privatevpn: models.Servers{Version: 1, Timestamp: 1},
Protonvpn: models.Servers{Version: 1, Timestamp: 1},
Purevpn: models.Servers{Version: 1, Timestamp: 1},
Surfshark: models.Servers{Version: 1, Timestamp: 1},
Torguard: models.Servers{Version: 1, Timestamp: 1},
VPNUnlimited: models.Servers{Version: 1, Timestamp: 1},
Vyprvpn: models.Servers{Version: 1, Timestamp: 1},
Wevpn: models.Servers{Version: 1, Timestamp: 1},
Windscribe: models.Servers{Version: 1, Timestamp: 1},
ProviderToServers: map[string]models.Servers{},
},
},
}
@@ -151,8 +142,13 @@ func Test_extractServersFromBytes(t *testing.T) {
ctrl := gomock.NewController(t)
logger := NewMockInfoErrorer(ctrl)
var previousLogCall *gomock.Call
for _, logged := range testCase.logged {
logger.EXPECT().Info(logged)
call := logger.EXPECT().Info(logged)
if previousLogCall != nil {
call.After(previousLogCall)
}
previousLogCall = call
}
s := &Storage{
@@ -161,9 +157,8 @@ func Test_extractServersFromBytes(t *testing.T) {
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
@@ -171,4 +166,25 @@ func Test_extractServersFromBytes(t *testing.T) {
assert.Equal(t, testCase.persisted, servers)
})
}
t.Run("hardcoded panic", func(t *testing.T) {
t.Parallel()
s := &Storage{}
allProviders := providers.All()
require.GreaterOrEqual(t, len(allProviders), 2)
b := []byte(`{}`)
hardcoded := models.AllServers{
ProviderToServers: map[string]models.Servers{
allProviders[0]: {},
// Missing provider allProviders[1]
},
}
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
assert.PanicsWithValue(t, expectedPanicValue, func() {
_, _ = s.extractServersFromBytes(b, hardcoded)
})
})
}

View File

@@ -7,27 +7,11 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
func countServers(allServers models.AllServers) int {
return len(allServers.Cyberghost.Servers) +
len(allServers.Expressvpn.Servers) +
len(allServers.Fastestvpn.Servers) +
len(allServers.HideMyAss.Servers) +
len(allServers.Ipvanish.Servers) +
len(allServers.Ivpn.Servers) +
len(allServers.Mullvad.Servers) +
len(allServers.Nordvpn.Servers) +
len(allServers.Perfectprivacy.Servers) +
len(allServers.Privado.Servers) +
len(allServers.Pia.Servers) +
len(allServers.Privatevpn.Servers) +
len(allServers.Protonvpn.Servers) +
len(allServers.Purevpn.Servers) +
len(allServers.Surfshark.Servers) +
len(allServers.Torguard.Servers) +
len(allServers.VPNUnlimited.Servers) +
len(allServers.Vyprvpn.Servers) +
len(allServers.Wevpn.Servers) +
len(allServers.Windscribe.Servers)
func countServers(allServers models.AllServers) (count int) {
for _, servers := range allServers.ProviderToServers {
count += len(servers.Servers)
}
return count
}
func (s *Storage) SyncServers() (err error) {
@@ -57,7 +41,7 @@ func (s *Storage) SyncServers() (err error) {
return nil
}
if err := flushToFile(s.filepath, s.mergedServers); err != nil {
if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
return fmt.Errorf("cannot write servers to file: %w", err)
}
return nil

View File

@@ -139,7 +139,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
l.stopped <- struct{}{}
case servers := <-serversCh:
l.setAllServers(servers)
if err := l.flusher.FlushToFile(servers); err != nil {
if err := l.flusher.FlushToFile(&servers); err != nil {
l.logger.Error(err.Error())
}
runWg.Wait()

View File

@@ -4,8 +4,10 @@ import (
"context"
"fmt"
"reflect"
"time"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/updater/providers/cyberghost"
"github.com/qdm12/gluetun/internal/updater/providers/expressvpn"
"github.com/qdm12/gluetun/internal/updater/providers/fastestvpn"
@@ -28,419 +30,96 @@ import (
"github.com/qdm12/gluetun/internal/updater/providers/windscribe"
)
func (u *updater) updateCyberghost(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Cyberghost.Servers))
servers, err := cyberghost.GetServers(ctx, u.presolver, minServers)
func (u *updater) updateProvider(ctx context.Context, provider string) (
warnings []string, err error) {
existingServers := u.getProviderServers(provider)
minServers := getMinServers(existingServers)
servers, warnings, err := u.getServers(ctx, provider, minServers)
if err != nil {
return err
return warnings, err
}
if reflect.DeepEqual(u.servers.Cyberghost.Servers, servers) {
return nil
if reflect.DeepEqual(existingServers, servers) {
return warnings, nil
}
u.servers.Cyberghost.Timestamp = u.timeNow().Unix()
u.servers.Cyberghost.Servers = servers
return nil
u.patchProvider(provider, servers)
return warnings, nil
}
func (u *updater) updateExpressvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Expressvpn.Servers))
servers, warnings, err := expressvpn.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("ExpressVPN: " + warning)
func (u *updater) getServers(ctx context.Context, provider string,
minServers int) (servers []models.Server, warnings []string, err error) {
switch provider {
case providers.Custom:
panic("cannot update custom provider")
case providers.Cyberghost:
servers, err = cyberghost.GetServers(ctx, u.presolver, minServers)
return servers, nil, err
case providers.Expressvpn:
return expressvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Fastestvpn:
return fastestvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.HideMyAss:
return hidemyass.GetServers(ctx, u.client, u.presolver, minServers)
case providers.Ipvanish:
return ipvanish.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Ivpn:
return ivpn.GetServers(ctx, u.client, u.presolver, minServers)
case providers.Mullvad:
servers, err = mullvad.GetServers(ctx, u.client, minServers)
return servers, nil, err
case providers.Nordvpn:
return nordvpn.GetServers(ctx, u.client, minServers)
case providers.Perfectprivacy:
return perfectprivacy.GetServers(ctx, u.unzipper, minServers)
case providers.Privado:
return privado.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers)
case providers.PrivateInternetAccess:
servers, err = pia.GetServers(ctx, u.client, minServers)
return servers, nil, err
case providers.Privatevpn:
return privatevpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Protonvpn:
return protonvpn.GetServers(ctx, u.client, minServers)
case providers.Purevpn:
return purevpn.GetServers(ctx, u.client, u.unzipper, u.presolver, minServers)
case providers.Surfshark:
return surfshark.GetServers(ctx, u.unzipper, u.client, u.presolver, minServers)
case providers.Torguard:
return torguard.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.VPNUnlimited:
return vpnunlimited.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Vyprvpn:
return vyprvpn.GetServers(ctx, u.unzipper, u.presolver, minServers)
case providers.Wevpn:
return wevpn.GetServers(ctx, u.presolver, minServers)
case providers.Windscribe:
servers, err = windscribe.GetServers(ctx, u.client, minServers)
return servers, nil, err
default:
panic("provider " + provider + " is unknown")
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Expressvpn.Servers, servers) {
return nil
}
u.servers.Expressvpn.Timestamp = u.timeNow().Unix()
u.servers.Expressvpn.Servers = servers
return nil
}
func (u *updater) updateFastestvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Fastestvpn.Servers))
servers, warnings, err := fastestvpn.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("FastestVPN: " + warning)
func (u *updater) getProviderServers(provider string) (servers []models.Server) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Fastestvpn.Servers, servers) {
return nil
}
u.servers.Fastestvpn.Timestamp = u.timeNow().Unix()
u.servers.Fastestvpn.Servers = servers
return nil
return providerServers.Servers
}
func (u *updater) updateHideMyAss(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.HideMyAss.Servers))
servers, warnings, err := hidemyass.GetServers(
ctx, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("HideMyAss: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.HideMyAss.Servers, servers) {
return nil
}
u.servers.HideMyAss.Timestamp = u.timeNow().Unix()
u.servers.HideMyAss.Servers = servers
return nil
}
func (u *updater) updateIpvanish(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Ipvanish.Servers))
servers, warnings, err := ipvanish.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Ipvanish: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Ipvanish.Servers, servers) {
return nil
}
u.servers.Ipvanish.Timestamp = u.timeNow().Unix()
u.servers.Ipvanish.Servers = servers
return nil
}
func (u *updater) updateIvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Ivpn.Servers))
servers, warnings, err := ivpn.GetServers(
ctx, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Ivpn: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Ivpn.Servers, servers) {
return nil
}
u.servers.Ivpn.Timestamp = u.timeNow().Unix()
u.servers.Ivpn.Servers = servers
return nil
}
func (u *updater) updateMullvad(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Mullvad.Servers))
servers, err := mullvad.GetServers(ctx, u.client, minServers)
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Mullvad.Servers, servers) {
return nil
}
u.servers.Mullvad.Timestamp = u.timeNow().Unix()
u.servers.Mullvad.Servers = servers
return nil
}
func (u *updater) updateNordvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Nordvpn.Servers))
servers, warnings, err := nordvpn.GetServers(ctx, u.client, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("NordVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Nordvpn.Servers, servers) {
return nil
}
u.servers.Nordvpn.Timestamp = u.timeNow().Unix()
u.servers.Nordvpn.Servers = servers
return nil
}
func (u *updater) updatePerfectprivacy(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Perfectprivacy.Servers))
servers, warnings, err := perfectprivacy.GetServers(ctx, u.unzipper, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn(providers.Perfectprivacy + ": " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Perfectprivacy.Servers, servers) {
return nil
}
u.servers.Perfectprivacy.Timestamp = u.timeNow().Unix()
u.servers.Perfectprivacy.Servers = servers
return nil
}
func (u *updater) updatePIA(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Pia.Servers))
servers, err := pia.GetServers(ctx, u.client, minServers)
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Pia.Servers, servers) {
return nil
}
u.servers.Pia.Timestamp = u.timeNow().Unix()
u.servers.Pia.Servers = servers
return nil
}
func (u *updater) updatePrivado(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Privado.Servers))
servers, warnings, err := privado.GetServers(
ctx, u.unzipper, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Privado: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Privado.Servers, servers) {
return nil
}
u.servers.Privado.Timestamp = u.timeNow().Unix()
u.servers.Privado.Servers = servers
return nil
}
func (u *updater) updatePrivatevpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Privatevpn.Servers))
servers, warnings, err := privatevpn.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("PrivateVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Privatevpn.Servers, servers) {
return nil
}
u.servers.Privatevpn.Timestamp = u.timeNow().Unix()
u.servers.Privatevpn.Servers = servers
return nil
}
func (u *updater) updateProtonvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Privatevpn.Servers))
servers, warnings, err := protonvpn.GetServers(ctx, u.client, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("ProtonVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Protonvpn.Servers, servers) {
return nil
}
u.servers.Protonvpn.Timestamp = u.timeNow().Unix()
u.servers.Protonvpn.Servers = servers
return nil
}
func (u *updater) updatePurevpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Purevpn.Servers))
servers, warnings, err := purevpn.GetServers(
ctx, u.client, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("PureVPN: " + warning)
}
}
if err != nil {
return fmt.Errorf("cannot update Purevpn servers: %w", err)
}
if reflect.DeepEqual(u.servers.Purevpn.Servers, servers) {
return nil
}
u.servers.Purevpn.Timestamp = u.timeNow().Unix()
u.servers.Purevpn.Servers = servers
return nil
}
func (u *updater) updateSurfshark(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Surfshark.Servers))
servers, warnings, err := surfshark.GetServers(
ctx, u.unzipper, u.client, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Surfshark: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Surfshark.Servers, servers) {
return nil
}
u.servers.Surfshark.Timestamp = u.timeNow().Unix()
u.servers.Surfshark.Servers = servers
return nil
}
func (u *updater) updateTorguard(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Torguard.Servers))
servers, warnings, err := torguard.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Torguard: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Torguard.Servers, servers) {
return nil
}
u.servers.Torguard.Timestamp = u.timeNow().Unix()
u.servers.Torguard.Servers = servers
return nil
}
func (u *updater) updateVPNUnlimited(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.VPNUnlimited.Servers))
servers, warnings, err := vpnunlimited.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn(providers.VPNUnlimited + ": " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.VPNUnlimited.Servers, servers) {
return nil
}
u.servers.VPNUnlimited.Timestamp = u.timeNow().Unix()
u.servers.VPNUnlimited.Servers = servers
return nil
}
func (u *updater) updateVyprvpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Vyprvpn.Servers))
servers, warnings, err := vyprvpn.GetServers(
ctx, u.unzipper, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("VyprVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Vyprvpn.Servers, servers) {
return nil
}
u.servers.Vyprvpn.Timestamp = u.timeNow().Unix()
u.servers.Vyprvpn.Servers = servers
return nil
}
func (u *updater) updateWevpn(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Wevpn.Servers))
servers, warnings, err := wevpn.GetServers(ctx, u.presolver, minServers)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("WeVPN: " + warning)
}
}
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Wevpn.Servers, servers) {
return nil
}
u.servers.Wevpn.Timestamp = u.timeNow().Unix()
u.servers.Wevpn.Servers = servers
return nil
}
func (u *updater) updateWindscribe(ctx context.Context) (err error) {
minServers := getMinServers(len(u.servers.Windscribe.Servers))
servers, err := windscribe.GetServers(ctx, u.client, minServers)
if err != nil {
return err
}
if reflect.DeepEqual(u.servers.Windscribe.Servers, servers) {
return nil
}
u.servers.Windscribe.Timestamp = u.timeNow().Unix()
u.servers.Windscribe.Servers = servers
return nil
}
func getMinServers(existingServers int) (minServers int) {
func getMinServers(servers []models.Server) (minServers int) {
const minRatio = 0.8
return int(minRatio * float64(existingServers))
return int(minRatio * float64(len(servers)))
}
func (u *updater) patchProvider(provider string, servers []models.Server) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
providerServers.Timestamp = time.Now().Unix()
providerServers.Servers = servers
u.servers.ProviderToServers[provider] = providerServers
}

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/qdm12/gluetun/internal/updater/unzip"
@@ -47,16 +46,17 @@ func New(settings settings.Updater, httpClient *http.Client,
}
}
type updateFunc func(ctx context.Context) (err error)
func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) {
for _, provider := range u.options.Providers {
u.logger.Info("updating " + strings.Title(provider) + " servers...")
updateProvider := u.getUpdateFunction(provider)
// TODO support servers offering only TCP or only UDP
// for NordVPN and PureVPN
err = updateProvider(ctx)
warnings, err := u.updateProvider(ctx, provider)
if *u.options.CLI {
for _, warning := range warnings {
u.logger.Warn(provider + ": " + warning)
}
}
if err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
@@ -67,52 +67,3 @@ func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServe
return u.servers, nil
}
func (u *updater) getUpdateFunction(provider string) (updateFunction updateFunc) {
switch provider {
case providers.Custom:
panic("cannot update custom provider")
case providers.Cyberghost:
return func(ctx context.Context) (err error) { return u.updateCyberghost(ctx) }
case providers.Expressvpn:
return func(ctx context.Context) (err error) { return u.updateExpressvpn(ctx) }
case providers.Fastestvpn:
return func(ctx context.Context) (err error) { return u.updateFastestvpn(ctx) }
case providers.HideMyAss:
return func(ctx context.Context) (err error) { return u.updateHideMyAss(ctx) }
case providers.Ipvanish:
return func(ctx context.Context) (err error) { return u.updateIpvanish(ctx) }
case providers.Ivpn:
return func(ctx context.Context) (err error) { return u.updateIvpn(ctx) }
case providers.Mullvad:
return func(ctx context.Context) (err error) { return u.updateMullvad(ctx) }
case providers.Nordvpn:
return func(ctx context.Context) (err error) { return u.updateNordvpn(ctx) }
case providers.Perfectprivacy:
return func(ctx context.Context) (err error) { return u.updatePerfectprivacy(ctx) }
case providers.Privado:
return func(ctx context.Context) (err error) { return u.updatePrivado(ctx) }
case providers.PrivateInternetAccess:
return func(ctx context.Context) (err error) { return u.updatePIA(ctx) }
case providers.Privatevpn:
return func(ctx context.Context) (err error) { return u.updatePrivatevpn(ctx) }
case providers.Protonvpn:
return func(ctx context.Context) (err error) { return u.updateProtonvpn(ctx) }
case providers.Purevpn:
return func(ctx context.Context) (err error) { return u.updatePurevpn(ctx) }
case providers.Surfshark:
return func(ctx context.Context) (err error) { return u.updateSurfshark(ctx) }
case providers.Torguard:
return func(ctx context.Context) (err error) { return u.updateTorguard(ctx) }
case providers.VPNUnlimited:
return func(ctx context.Context) (err error) { return u.updateVPNUnlimited(ctx) }
case providers.Vyprvpn:
return func(ctx context.Context) (err error) { return u.updateVyprvpn(ctx) }
case providers.Wevpn:
return func(ctx context.Context) (err error) { return u.updateWevpn(ctx) }
case providers.Windscribe:
return func(ctx context.Context) (err error) { return u.updateWindscribe(ctx) }
default:
panic("provider " + provider + " is unknown")
}
}