chore(all): memory and thread safe storage

- settings: get filter choices from storage for settings validation
- updater: update servers to the storage
- storage: minimal deep copying and data duplication
- storage: add merged servers mutex for thread safety
- connection: filter servers in storage
- formatter: format servers to Markdown in storage
- PIA: get server by name from storage directly
- Updater: get servers count from storage directly
- Updater: equality check done in storage, fix #882
This commit is contained in:
Quentin McGaw
2022-06-05 14:58:46 +00:00
parent 1e6b4ed5eb
commit 36b504609b
84 changed files with 1267 additions and 877 deletions

View File

@@ -215,9 +215,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err return err
} }
allServers := storage.GetServers() err = allSettings.Validate(storage)
err = allSettings.Validate(allServers)
if err != nil { if err != nil {
return err return err
} }
@@ -378,7 +376,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
vpnLogger := logger.New(log.SetComponent("vpn")) vpnLogger := logger.New(log.SetComponent("vpn"))
vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts, vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts,
allServers, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper, storage, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient, cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
buildInfo, *allSettings.Version.Enabled) buildInfo, *allSettings.Version.Enabled)
vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler( vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler(
@@ -386,8 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
go vpnLooper.Run(vpnCtx, vpnDone) go vpnLooper.Run(vpnCtx, vpnDone)
updaterLooper := updater.NewLooper(allSettings.Updater, updaterLooper := updater.NewLooper(allSettings.Updater,
allServers, storage, vpnLooper.SetServers, httpClient, storage, httpClient, logger.New(log.SetComponent("updater")))
logger.New(log.SetComponent("updater")))
updaterHandler, updaterCtx, updaterDone := goshutdown.NewGoRoutineHandler( updaterHandler, updaterCtx, updaterDone := goshutdown.NewGoRoutineHandler(
"updater", goroutine.OptionTimeout(defaultShutdownTimeout)) "updater", goroutine.OptionTimeout(defaultShutdownTimeout))
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker

View File

@@ -10,7 +10,6 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"golang.org/x/text/cases" "golang.org/x/text/cases"
"golang.org/x/text/language" "golang.org/x/text/language"
@@ -80,9 +79,8 @@ func (c *CLI) FormatServers(args []string) error {
if err != nil { if err != nil {
return fmt.Errorf("cannot create servers storage: %w", err) return fmt.Errorf("cannot create servers storage: %w", err)
} }
currentServers := storage.GetServers()
formatted := formatServers(currentServers, providerToFormat) formatted := storage.FormatToMarkdown(providerToFormat)
output = filepath.Clean(output) output = filepath.Clean(output)
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
@@ -103,11 +101,3 @@ func (c *CLI) FormatServers(args []string) error {
return nil return nil
} }
func formatServers(allServers models.AllServers, provider string) (formatted string) {
servers, ok := allServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
}
return servers.ToMarkdown(provider)
}

View File

@@ -25,18 +25,17 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) e
if err != nil { if err != nil {
return err return err
} }
allServers := storage.GetServers()
allSettings, err := source.Read() allSettings, err := source.Read()
if err != nil { if err != nil {
return err return err
} }
if err = allSettings.Validate(allServers); err != nil { if err = allSettings.Validate(storage); err != nil {
return err return err
} }
providerConf := provider.New(*allSettings.VPN.Provider.Name, allServers, time.Now) providerConf := provider.New(*allSettings.VPN.Provider.Name, storage, time.Now)
connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection) connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection)
if err != nil { if err != nil {
return err return err

View File

@@ -2,20 +2,17 @@ package cli
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"time" "time"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
) )
@@ -83,41 +80,19 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
if err != nil { if err != nil {
return fmt.Errorf("cannot create servers storage: %w", err) return fmt.Errorf("cannot create servers storage: %w", err)
} }
currentServers := storage.GetServers()
updater := updater.New(options, httpClient, currentServers, logger) updater := updater.New(options, httpClient, storage, logger)
allServers, err := updater.UpdateServers(ctx) err = updater.UpdateServers(ctx)
if err != nil { if err != nil {
return fmt.Errorf("cannot update server information: %w", err) return fmt.Errorf("cannot update server information: %w", err)
} }
if endUserMode {
if err := storage.FlushToFile(&allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err)
}
}
if maintainerMode { if maintainerMode {
if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil { err := storage.FlushToFile(c.repoServersPath)
return fmt.Errorf("cannot write updated information to file: %w", err) if err != nil {
return fmt.Errorf("cannot write servers data to embedded JSON file: %w", err)
} }
} }
return nil return nil
} }
func writeToEmbeddedJSON(repoServersPath string,
allServers *models.AllServers) error {
const perms = 0600
f, err := os.OpenFile(repoServersPath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
if err != nil {
return err
}
defer f.Close()
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
return encoder.Encode(allServers)
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gotree" "github.com/qdm12/gotree"
) )
@@ -23,7 +22,7 @@ type Provider struct {
} }
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
func (p *Provider) validate(vpnType string, allServers models.AllServers) (err error) { func (p *Provider) validate(vpnType string, storage Storage) (err error) {
// Validate Name // Validate Name
var validNames []string var validNames []string
if vpnType == vpn.OpenVPN { if vpnType == vpn.OpenVPN {
@@ -42,7 +41,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
ErrVPNProviderNameNotValid, *p.Name, helpers.ChoicesOrString(validNames)) ErrVPNProviderNameNotValid, *p.Name, helpers.ChoicesOrString(validNames))
} }
err = p.ServerSelection.validate(*p.Name, allServers) err = p.ServerSelection.validate(*p.Name, storage)
if err != nil { if err != nil {
return fmt.Errorf("server selection: %w", err) return fmt.Errorf("server selection: %w", err)
} }

View File

@@ -68,21 +68,19 @@ var (
) )
func (ss *ServerSelection) validate(vpnServiceProvider string, func (ss *ServerSelection) validate(vpnServiceProvider string,
allServers models.AllServers) (err error) { storage Storage) (err error) {
switch ss.VPN { switch ss.VPN {
case vpn.OpenVPN, vpn.Wireguard: case vpn.OpenVPN, vpn.Wireguard:
default: default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN) return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
} }
countryChoices, regionChoices, cityChoices, filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage)
ispChoices, nameChoices, hostnameChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, allServers)
if err != nil { if err != nil {
return err // already wrapped error return err // already wrapped error
} }
err = validateServerFilters(*ss, countryChoices, regionChoices, cityChoices, err = validateServerFilters(*ss, filterChoices)
ispChoices, nameChoices, hostnameChoices)
if err != nil { if err != nil {
if errors.Is(err, helpers.ErrNoChoice) { if errors.Is(err, helpers.ErrNoChoice) {
return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err) return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err)
@@ -135,63 +133,48 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
return nil return nil
} }
func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection, func getLocationFilterChoices(vpnServiceProvider string,
allServers models.AllServers) ( ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices,
countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices []string,
err error) { err error) {
providerServers, ok := allServers.ProviderToServers[vpnServiceProvider] filterChoices = storage.GetFilterChoices(vpnServiceProvider)
if !ok && vpnServiceProvider != providers.Custom {
panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider))
}
servers := providerServers.Servers
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
ispChoices = validation.ExtractISPs(servers)
nameChoices = validation.ExtractServerNames(servers)
hostnameChoices = validation.ExtractHostnames(servers)
if vpnServiceProvider == providers.Surfshark { if vpnServiceProvider == providers.Surfshark {
// // Retro compatibility // // Retro compatibility
// TODO v4 remove // TODO v4 remove
regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...) filterChoices.Regions = append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...)
if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil { if err := helpers.AreAllOneOf(ss.Regions, filterChoices.Regions); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err) return models.FilterChoices{}, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
} }
*ss = surfsharkRetroRegion(*ss) *ss = surfsharkRetroRegion(*ss)
} }
return countryChoices, regionChoices, cityChoices, return filterChoices, nil
ispChoices, nameChoices, hostnameChoices, nil
} }
// validateServerFilters validates filters against the choices given as arguments. // validateServerFilters validates filters against the choices given as arguments.
// Set an argument to nil to pass the check for a particular filter. // Set an argument to nil to pass the check for a particular filter.
func validateServerFilters(settings ServerSelection, func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices) (err error) {
countryChoices, regionChoices, cityChoices, ispChoices, if err := helpers.AreAllOneOf(settings.Countries, filterChoices.Countries); err != nil {
nameChoices, hostnameChoices []string) (err error) {
if err := helpers.AreAllOneOf(settings.Countries, countryChoices); err != nil {
return fmt.Errorf("%w: %s", ErrCountryNotValid, err) return fmt.Errorf("%w: %s", ErrCountryNotValid, err)
} }
if err := helpers.AreAllOneOf(settings.Regions, regionChoices); err != nil { if err := helpers.AreAllOneOf(settings.Regions, filterChoices.Regions); err != nil {
return fmt.Errorf("%w: %s", ErrRegionNotValid, err) return fmt.Errorf("%w: %s", ErrRegionNotValid, err)
} }
if err := helpers.AreAllOneOf(settings.Cities, cityChoices); err != nil { if err := helpers.AreAllOneOf(settings.Cities, filterChoices.Cities); err != nil {
return fmt.Errorf("%w: %s", ErrCityNotValid, err) return fmt.Errorf("%w: %s", ErrCityNotValid, err)
} }
if err := helpers.AreAllOneOf(settings.ISPs, ispChoices); err != nil { if err := helpers.AreAllOneOf(settings.ISPs, filterChoices.ISPs); err != nil {
return fmt.Errorf("%w: %s", ErrISPNotValid, err) return fmt.Errorf("%w: %s", ErrISPNotValid, err)
} }
if err := helpers.AreAllOneOf(settings.Hostnames, hostnameChoices); err != nil { if err := helpers.AreAllOneOf(settings.Hostnames, filterChoices.Hostnames); err != nil {
return fmt.Errorf("%w: %s", ErrHostnameNotValid, err) return fmt.Errorf("%w: %s", ErrHostnameNotValid, err)
} }
if err := helpers.AreAllOneOf(settings.Names, nameChoices); err != nil { if err := helpers.AreAllOneOf(settings.Names, filterChoices.Names); err != nil {
return fmt.Errorf("%w: %s", ErrNameNotValid, err) return fmt.Errorf("%w: %s", ErrNameNotValid, err)
} }

View File

@@ -24,10 +24,14 @@ type Settings struct {
Pprof pprof.Settings Pprof pprof.Settings
} }
type Storage interface {
GetFilterChoices(provider string) models.FilterChoices
}
// Validate validates all the settings and returns an error // Validate validates all the settings and returns an error
// if one of them is not valid. // if one of them is not valid.
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
func (s *Settings) Validate(allServers models.AllServers) (err error) { func (s *Settings) Validate(storage Storage) (err error) {
nameToValidation := map[string]func() error{ nameToValidation := map[string]func() error{
"control server": s.ControlServer.validate, "control server": s.ControlServer.validate,
"dns": s.DNS.validate, "dns": s.DNS.validate,
@@ -42,7 +46,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) {
"version": s.Version.validate, "version": s.Version.validate,
// Pprof validation done in pprof constructor // Pprof validation done in pprof constructor
"VPN": func() error { "VPN": func() error {
return s.VPN.validate(allServers) return s.VPN.validate(storage)
}, },
} }
@@ -91,7 +95,7 @@ func (s *Settings) MergeWith(other Settings) {
} }
func (s *Settings) OverrideWith(other Settings, func (s *Settings) OverrideWith(other Settings,
allServers models.AllServers) (err error) { storage Storage) (err error) {
patchedSettings := s.copy() patchedSettings := s.copy()
patchedSettings.ControlServer.overrideWith(other.ControlServer) patchedSettings.ControlServer.overrideWith(other.ControlServer)
patchedSettings.DNS.overrideWith(other.DNS) patchedSettings.DNS.overrideWith(other.DNS)
@@ -106,7 +110,7 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.Version.overrideWith(other.Version) patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.overrideWith(other.VPN) patchedSettings.VPN.overrideWith(other.VPN)
patchedSettings.Pprof.MergeWith(other.Pprof) patchedSettings.Pprof.MergeWith(other.Pprof)
err = patchedSettings.Validate(allServers) err = patchedSettings.Validate(storage)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -35,17 +35,18 @@ func (u Updater) Validate() (err error) {
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod) ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
} }
for i, provider := range u.Providers { validProviders := providers.All()
for _, provider := range u.Providers {
valid := false valid := false
for _, validProvider := range providers.All() { for _, validProvider := range validProviders {
if provider == validProvider { if provider == validProvider {
valid = true valid = true
break break
} }
} }
if !valid { if !valid {
return fmt.Errorf("%w: %s at index %d", return fmt.Errorf("%w: %q can only be one of %s",
ErrVPNProviderNameNotValid, provider, i) ErrVPNProviderNameNotValid, provider, helpers.ChoicesOrString(validProviders))
} }
} }

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gotree" "github.com/qdm12/gotree"
) )
@@ -21,7 +20,7 @@ type VPN struct {
} }
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) validate(allServers models.AllServers) (err error) { func (v *VPN) validate(storage Storage) (err error) {
// Validate Type // Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
if !helpers.IsOneOf(v.Type, validVPNTypes...) { if !helpers.IsOneOf(v.Type, validVPNTypes...) {
@@ -29,7 +28,7 @@ func (v *VPN) validate(allServers models.AllServers) (err error) {
ErrVPNTypeNotValid, v.Type, strings.Join(validVPNTypes, ", ")) ErrVPNTypeNotValid, v.Type, strings.Join(validVPNTypes, ", "))
} }
err = v.Provider.validate(v.Type, allServers) err = v.Provider.validate(v.Type, storage)
if err != nil { if err != nil {
return fmt.Errorf("provider settings: %w", err) return fmt.Errorf("provider settings: %w", err)
} }

View File

@@ -0,0 +1,10 @@
package models
type FilterChoices struct {
Countries []string
Regions []string
Cities []string
ISPs []string
Names []string
Hostnames []string
}

View File

@@ -1,51 +0,0 @@
package models
import (
"net"
)
func (a AllServers) GetCopy() (allServersCopy AllServers) {
allServersCopy.Version = a.Version
allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
for provider, servers := range a.ProviderToServers {
allServersCopy.ProviderToServers[provider] = Servers{
Version: servers.Version,
Timestamp: servers.Timestamp,
Servers: copyServers(servers.Servers),
}
}
return allServersCopy
}
func copyServers(servers []Server) (serversCopy []Server) {
if servers == nil {
return nil
}
serversCopy = make([]Server, len(servers))
for i, server := range servers {
serversCopy[i] = server
serversCopy[i].IPs = copyIPs(server.IPs)
}
return serversCopy
}
func copyIPs(toCopy []net.IP) (copied []net.IP) {
if toCopy == nil {
return nil
}
copied = make([]net.IP, len(toCopy))
for i := range toCopy {
copied[i] = copyIP(toCopy[i])
}
return copied
}
func copyIP(toCopy net.IP) (copied net.IP) {
copied = make(net.IP, len(toCopy))
copy(copied, toCopy)
return copied
}

View File

@@ -1,173 +0,0 @@
package models
import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_GetCopy(t *testing.T) {
allServers := AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 2,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Expressvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Fastestvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.HideMyAss: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Ipvanish: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Ivpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Mullvad: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Nordvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Perfectprivacy: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Privado: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.PrivateInternetAccess: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Privatevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Protonvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Purevpn: {
Version: 1,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Surfshark: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Torguard: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.VPNUnlimited: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Vyprvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Wevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Windscribe: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
},
}
servers := allServers.GetCopy()
assert.Equal(t, allServers, servers)
}
func Test_copyIPs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
toCopy []net.IP
copied []net.IP
}{
"nil": {},
"empty": {
toCopy: []net.IP{},
copied: []net.IP{},
},
"single IP": {
toCopy: []net.IP{{1, 1, 1, 1}},
copied: []net.IP{{1, 1, 1, 1}},
},
"two IPs": {
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Reserver leading 9 for copy modifications below
for _, ipToCopy := range testCase.toCopy {
require.NotEqual(t, 9, ipToCopy[0])
}
copied := copyIPs(testCase.toCopy)
assert.Equal(t, testCase.copied, copied)
if len(copied) > 0 {
original := testCase.toCopy[0][0]
testCase.toCopy[0][0] = 9
assert.NotEqual(t, 9, copied[0][0])
testCase.toCopy[0][0] = original
copied[0][0] = 9
assert.NotEqual(t, 9, testCase.toCopy[0][0])
}
})
}
}

View File

@@ -2,6 +2,7 @@ package models
import ( import (
"net" "net"
"reflect"
) )
type Server struct { type Server struct {
@@ -26,3 +27,28 @@ type Server struct {
PortForward bool `json:"port_forward,omitempty"` PortForward bool `json:"port_forward,omitempty"`
IPs []net.IP `json:"ips,omitempty"` IPs []net.IP `json:"ips,omitempty"`
} }
func (s *Server) Equal(other Server) (equal bool) {
if !ipsAreEqual(s.IPs, other.IPs) {
return false
}
serverCopy := *s
serverCopy.IPs = nil
other.IPs = nil
return reflect.DeepEqual(serverCopy, other)
}
func ipsAreEqual(a, b []net.IP) (equal bool) {
if len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equal(b[i]) {
return false
}
}
return true
}

View File

@@ -0,0 +1,120 @@
package models
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_Server_Equal(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
a *Server
b Server
equal bool
}{
"same IPs": {
a: &Server{
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
b: Server{
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
equal: true,
},
"same IP strings": {
a: &Server{
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
b: Server{
IPs: []net.IP{{1, 2, 3, 4}},
},
equal: true,
},
"different IPs": {
a: &Server{
IPs: []net.IP{{1, 2, 3, 4}, {2, 3, 4, 5}},
},
b: Server{
IPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 4}},
},
},
"all fields equal": {
a: &Server{
VPN: "vpn",
Country: "country",
Region: "region",
City: "city",
ISP: "isp",
Owned: true,
Number: 1,
ServerName: "server_name",
Hostname: "hostname",
TCP: true,
UDP: true,
OvpnX509: "x509",
RetroLoc: "retroloc",
MultiHop: true,
WgPubKey: "wgpubkey",
Free: true,
Stream: true,
PortForward: true,
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
b: Server{
VPN: "vpn",
Country: "country",
Region: "region",
City: "city",
ISP: "isp",
Owned: true,
Number: 1,
ServerName: "server_name",
Hostname: "hostname",
TCP: true,
UDP: true,
OvpnX509: "x509",
RetroLoc: "retroloc",
MultiHop: true,
WgPubKey: "wgpubkey",
Free: true,
Stream: true,
PortForward: true,
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
equal: true,
},
"different field": {
a: &Server{
VPN: "vpn",
},
b: Server{
VPN: "other vpn",
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ipsOfANotNil := testCase.a.IPs != nil
ipsOfBNotNil := testCase.b.IPs != nil
equal := testCase.a.Equal(testCase.b)
assert.Equal(t, testCase.equal, equal)
// Ensure IPs field is not modified
if ipsOfANotNil {
assert.NotNil(t, testCase.a)
}
if ipsOfBNotNil {
assert.NotNil(t, testCase.b)
}
})
}
}

View File

@@ -15,18 +15,6 @@ type AllServers struct {
ProviderToServers map[string]Servers ProviderToServers map[string]Servers
} }
func (a *AllServers) ServersSlice(provider string) []Server {
if provider == providers.Custom {
return nil
}
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
return copyServers(servers.Servers)
}
var _ json.Marshaler = (*AllServers)(nil) var _ json.Marshaler = (*AllServers)(nil)
// MarshalJSON marshals all servers to JSON. // MarshalJSON marshals all servers to JSON.

View File

@@ -0,0 +1,66 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/provider/common (interfaces: Storage)
// Package common is a generated GoMock package.
package common
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
settings "github.com/qdm12/gluetun/internal/configuration/settings"
models "github.com/qdm12/gluetun/internal/models"
)
// MockStorage is a mock of Storage interface.
type MockStorage struct {
ctrl *gomock.Controller
recorder *MockStorageMockRecorder
}
// MockStorageMockRecorder is the mock recorder for MockStorage.
type MockStorageMockRecorder struct {
mock *MockStorage
}
// NewMockStorage creates a new mock instance.
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
mock := &MockStorage{ctrl: ctrl}
mock.recorder = &MockStorageMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
// FilterServers mocks base method.
func (m *MockStorage) FilterServers(arg0 string, arg1 settings.ServerSelection) ([]models.Server, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterServers", arg0, arg1)
ret0, _ := ret[0].([]models.Server)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FilterServers indicates an expected call of FilterServers.
func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1)
}
// GetServerByName mocks base method.
func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1)
ret0, _ := ret[0].(models.Server)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetServerByName indicates an expected call of GetServerByName.
func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1)
}

View File

@@ -0,0 +1,5 @@
package common
// Exceptionally, the storage mock is exported since it is used by all
// provider subpackages tests, and it reduces test code duplication a lot.
//go:generate mockgen -destination=mocks.go -package $GOPACKAGE . Storage

View File

@@ -0,0 +1,12 @@
package common
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
)
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}

View File

@@ -2,6 +2,7 @@ package cyberghost
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Cyberghost,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost), NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost),
} }

View File

@@ -2,6 +2,7 @@ package expressvpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 1195, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(0, 1195, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Expressvpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -1,41 +1,63 @@
package expressvpn package expressvpn
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Provider_GetConnection(t *testing.T) { func Test_Provider_GetConnection(t *testing.T) {
t.Parallel() t.Parallel()
const provider = providers.Expressvpn
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct { testCases := map[string]struct {
servers []models.Server filteredServers []models.Server
selection settings.ServerSelection storageErr error
connection models.Connection selection settings.ServerSelection
errWrapped error connection models.Connection
errMessage string errWrapped error
errMessage string
panicMessage string
}{ }{
"no server": { "error": {
selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn), storageErr: errTest,
errWrapped: utils.ErrNoServer, errWrapped: errTest,
errMessage: "no server", errMessage: "cannot filter servers: test error",
}, },
"no filter": { "default OpenVPN TCP port": {
servers: []models.Server{ filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
}, },
selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn), selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
panicMessage: "no default OpenVPN TCP port is defined!",
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1), IP: net.IPv4(1, 1, 1, 1),
@@ -43,38 +65,14 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"target IP": { "default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{ selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2), VPN: vpn.Wireguard,
}.WithDefaults(providers.Expressvpn), }.WithDefaults(provider),
servers: []models.Server{ panicMessage: "no default Wireguard port is defined!",
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1195,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Expressvpn),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1195,
Protocol: constants.UDP,
Hostname: "b",
},
}, },
} }
@@ -82,12 +80,23 @@ func Test_Provider_GetConnection(t *testing.T) {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0) randSource := rand.NewSource(0)
m := New(testCase.servers, randSource) provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection) if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
_, _ = provider.GetConnection(testCase.selection)
})
return
}
connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped) assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil { if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn),
} }

View File

@@ -2,6 +2,7 @@ package fastestvpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(4443, 4443, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(4443, 4443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Fastestvpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn),
} }

View File

@@ -2,6 +2,7 @@ package hidemyass
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(8080, 553, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(8080, 553, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.HideMyAss,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss), NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss),
} }

View File

@@ -2,6 +2,7 @@ package ipvanish
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Ipvanish,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish), NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish),
} }

View File

@@ -2,6 +2,7 @@ package ivpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 58237) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 1194, 58237) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Ivpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -1,41 +1,67 @@
package ivpn package ivpn
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Provider_GetConnection(t *testing.T) { func Test_Provider_GetConnection(t *testing.T) {
t.Parallel() t.Parallel()
const provider = providers.Ivpn
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct { testCases := map[string]struct {
servers []models.Server filteredServers []models.Server
selection settings.ServerSelection storageErr error
connection models.Connection selection settings.ServerSelection
errWrapped error connection models.Connection
errMessage string errWrapped error
errMessage string
}{ }{
"no server available": { "error": {
selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn), storageErr: errTest,
errWrapped: utils.ErrNoServer, errWrapped: errTest,
errMessage: "no server", errMessage: "cannot filter servers: test error",
}, },
"no filter": { "default OpenVPN TCP port": {
servers: []models.Server{ filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
}, },
selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn), selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 443,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1), IP: net.IPv4(1, 1, 1, 1),
@@ -43,51 +69,36 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"target IP": { "default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{ selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2), VPN: vpn.Wireguard,
}.WithDefaults(providers.Ivpn), }.WithDefaults(provider),
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.Wireguard,
IP: net.IPv4(2, 2, 2, 2), IP: net.IPv4(1, 1, 1, 1),
Port: 1194, Port: 58237,
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Ivpn),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
Hostname: "b",
},
},
} }
for name, testCase := range testCases { for name, testCase := range testCases {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0) randSource := rand.NewSource(0)
m := New(testCase.servers, randSource) provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection) connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped) assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil { if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn),
} }

View File

@@ -2,6 +2,7 @@ package mullvad
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Mullvad,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -1,41 +1,67 @@
package mullvad package mullvad
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Provider_GetConnection(t *testing.T) { func Test_Provider_GetConnection(t *testing.T) {
t.Parallel() t.Parallel()
const provider = providers.Mullvad
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct { testCases := map[string]struct {
servers []models.Server filteredServers []models.Server
selection settings.ServerSelection storageErr error
connection models.Connection selection settings.ServerSelection
errWrapped error connection models.Connection
errMessage string errWrapped error
errMessage string
}{ }{
"no server available": { "error": {
selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad), storageErr: errTest,
errWrapped: utils.ErrNoServer, errWrapped: errTest,
errMessage: "no server", errMessage: "cannot filter servers: test error",
}, },
"no filter": { "default OpenVPN TCP port": {
servers: []models.Server{ filteredServers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
}, },
selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad), selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 443,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1), IP: net.IPv4(1, 1, 1, 1),
@@ -43,36 +69,17 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"target IP": { "default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{ selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2), VPN: vpn.Wireguard,
}.WithDefaults(providers.Mullvad), }.WithDefaults(provider),
servers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.Wireguard,
IP: net.IPv4(2, 2, 2, 2), IP: net.IPv4(1, 1, 1, 1),
Port: 1194, Port: 51820,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Mullvad),
servers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Hostname: "b",
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
@@ -82,12 +89,16 @@ func Test_Provider_GetConnection(t *testing.T) {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0) randSource := rand.NewSource(0)
m := New(testCase.servers, randSource) provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection) connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped) assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil { if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Mullvad), NoPortForwarder: utils.NewNoPortForwarding(providers.Mullvad),
} }

View File

@@ -2,6 +2,7 @@ package nordvpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Nordvpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Nordvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Nordvpn),
} }

View File

@@ -2,6 +2,7 @@ package perfectprivacy
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Perfectprivacy,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Perfectprivacy), NoPortForwarder: utils.NewNoPortForwarding(providers.Perfectprivacy),
} }

View File

@@ -2,6 +2,7 @@ package privado
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Privado,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Privado), NoPortForwarder: utils.NewNoPortForwarding(providers.Privado),
} }

View File

@@ -2,6 +2,7 @@ package privateinternetaccess
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/privateinternetaccess/presets" "github.com/qdm12/gluetun/internal/provider/privateinternetaccess/presets"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
@@ -20,5 +21,6 @@ func (p *Provider) GetConnection(selection settings.ServerSelection) (
defaults.OpenVPNUDPPort = 1197 defaults.OpenVPNUDPPort = 1197
} }
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.PrivateInternetAccess,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -15,25 +15,24 @@ import (
"strings" "strings"
"time" "time"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/format" "github.com/qdm12/golibs/format"
) )
var ( var (
ErrGatewayIPIsNil = errors.New("gateway IP address is nil") ErrServerNameNotFound = errors.New("server name not found in servers")
ErrServerNameEmpty = errors.New("server name is empty") ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
ErrServerNameEmpty = errors.New("server name is empty")
) )
// PortForward obtains a VPN server side port forwarded from PIA. // PortForward obtains a VPN server side port forwarded from PIA.
func (p *Provider) PortForward(ctx context.Context, client *http.Client, func (p *Provider) PortForward(ctx context.Context, client *http.Client,
logger utils.Logger, gateway net.IP, serverName string) ( logger utils.Logger, gateway net.IP, serverName string) (
port uint16, err error) { port uint16, err error) {
var server models.Server server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
for _, server = range p.servers { if !ok {
if server.ServerName == serverName { return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
break
}
} }
if !server.PortForward { if !server.PortForward {

View File

@@ -5,11 +5,11 @@ import (
"time" "time"
"github.com/qdm12/gluetun/internal/constants/openvpn" "github.com/qdm12/gluetun/internal/constants/openvpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
timeNow func() time.Time timeNow func() time.Time
// Port forwarding // Port forwarding
@@ -17,11 +17,11 @@ type Provider struct {
authFilePath string authFilePath string
} }
func New(servers []models.Server, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
timeNow func() time.Time) *Provider { timeNow func() time.Time) *Provider {
const jsonPortForwardPath = "/gluetun/piaportforward.json" const jsonPortForwardPath = "/gluetun/piaportforward.json"
return &Provider{ return &Provider{
servers: servers, storage: storage,
timeNow: timeNow, timeNow: timeNow,
randSource: randSource, randSource: randSource,
portForwardPath: jsonPortForwardPath, portForwardPath: jsonPortForwardPath,

View File

@@ -2,6 +2,7 @@ package privatevpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Privatevpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn),
} }

View File

@@ -2,6 +2,7 @@ package protonvpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Protonvpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn),
} }

View File

@@ -50,52 +50,57 @@ type PortForwarder interface {
port uint16, gateway net.IP, serverName string) (err error) port uint16, gateway net.IP, serverName string) (err error)
} }
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider { type Storage interface {
serversSlice := allServers.ServersSlice(provider) FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}
func New(provider string, storage Storage, timeNow func() time.Time) Provider {
randSource := rand.NewSource(timeNow().UnixNano()) randSource := rand.NewSource(timeNow().UnixNano())
switch provider { switch provider {
case providers.Custom: case providers.Custom:
return custom.New() return custom.New()
case providers.Cyberghost: case providers.Cyberghost:
return cyberghost.New(serversSlice, randSource) return cyberghost.New(storage, randSource)
case providers.Expressvpn: case providers.Expressvpn:
return expressvpn.New(serversSlice, randSource) return expressvpn.New(storage, randSource)
case providers.Fastestvpn: case providers.Fastestvpn:
return fastestvpn.New(serversSlice, randSource) return fastestvpn.New(storage, randSource)
case providers.HideMyAss: case providers.HideMyAss:
return hidemyass.New(serversSlice, randSource) return hidemyass.New(storage, randSource)
case providers.Ipvanish: case providers.Ipvanish:
return ipvanish.New(serversSlice, randSource) return ipvanish.New(storage, randSource)
case providers.Ivpn: case providers.Ivpn:
return ivpn.New(serversSlice, randSource) return ivpn.New(storage, randSource)
case providers.Mullvad: case providers.Mullvad:
return mullvad.New(serversSlice, randSource) return mullvad.New(storage, randSource)
case providers.Nordvpn: case providers.Nordvpn:
return nordvpn.New(serversSlice, randSource) return nordvpn.New(storage, randSource)
case providers.Perfectprivacy: case providers.Perfectprivacy:
return perfectprivacy.New(serversSlice, randSource) return perfectprivacy.New(storage, randSource)
case providers.Privado: case providers.Privado:
return privado.New(serversSlice, randSource) return privado.New(storage, randSource)
case providers.PrivateInternetAccess: case providers.PrivateInternetAccess:
return privateinternetaccess.New(serversSlice, randSource, timeNow) return privateinternetaccess.New(storage, randSource, timeNow)
case providers.Privatevpn: case providers.Privatevpn:
return privatevpn.New(serversSlice, randSource) return privatevpn.New(storage, randSource)
case providers.Protonvpn: case providers.Protonvpn:
return protonvpn.New(serversSlice, randSource) return protonvpn.New(storage, randSource)
case providers.Purevpn: case providers.Purevpn:
return purevpn.New(serversSlice, randSource) return purevpn.New(storage, randSource)
case providers.Surfshark: case providers.Surfshark:
return surfshark.New(serversSlice, randSource) return surfshark.New(storage, randSource)
case providers.Torguard: case providers.Torguard:
return torguard.New(serversSlice, randSource) return torguard.New(storage, randSource)
case providers.VPNUnlimited: case providers.VPNUnlimited:
return vpnunlimited.New(serversSlice, randSource) return vpnunlimited.New(storage, randSource)
case providers.Vyprvpn: case providers.Vyprvpn:
return vyprvpn.New(serversSlice, randSource) return vyprvpn.New(storage, randSource)
case providers.Wevpn: case providers.Wevpn:
return wevpn.New(serversSlice, randSource) return wevpn.New(storage, randSource)
case providers.Windscribe: case providers.Windscribe:
return windscribe.New(serversSlice, randSource) return windscribe.New(storage, randSource)
default: default:
panic("provider " + provider + " is unknown") // should never occur panic("provider " + provider + " is unknown") // should never occur
} }

View File

@@ -2,6 +2,7 @@ package purevpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(80, 53, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(80, 53, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Purevpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn),
} }

View File

@@ -2,6 +2,7 @@ package surfshark
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(1443, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(1443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Surfshark,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark), NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark),
} }

View File

@@ -2,6 +2,7 @@ package torguard
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(1912, 1912, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(1912, 1912, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Torguard,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard), NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard),
} }

View File

@@ -1,7 +1,7 @@
package utils package utils
import ( import (
"errors" "fmt"
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
@@ -24,20 +24,20 @@ func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort,
} }
} }
var ErrNoServer = errors.New("no server") type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
}
func GetConnection(servers []models.Server, func GetConnection(provider string,
storage Storage,
selection settings.ServerSelection, selection settings.ServerSelection,
defaults ConnectionDefaults, defaults ConnectionDefaults,
randSource rand.Source) ( randSource rand.Source) (
connection models.Connection, err error) { connection models.Connection, err error) {
if len(servers) == 0 { servers, err := storage.FilterServers(provider, selection)
return connection, ErrNoServer if err != nil {
} return connection, fmt.Errorf("cannot filter servers: %w", err)
servers = filterServers(servers, selection)
if len(servers) == 0 {
return connection, noServerFoundError(selection)
} }
protocol := getProtocol(selection) protocol := getProtocol(selection)

View File

@@ -1,23 +1,30 @@
package utils package utils
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_GetConnection(t *testing.T) { func Test_GetConnection(t *testing.T) {
t.Parallel() t.Parallel()
errTest := errors.New("test error")
testCases := map[string]struct { testCases := map[string]struct {
servers []models.Server provider string
filteredServers []models.Server
filterError error
serverSelection settings.ServerSelection serverSelection settings.ServerSelection
defaults ConnectionDefaults defaults ConnectionDefaults
randSource rand.Source randSource rand.Source
@@ -25,25 +32,13 @@ func Test_GetConnection(t *testing.T) {
errWrapped error errWrapped error
errMessage string errMessage string
}{ }{
"no server": { "storage filter error": {
serverSelection: settings.ServerSelection{}. filterError: errTest,
WithDefaults(providers.Mullvad), errWrapped: errTest,
errWrapped: ErrNoServer, errMessage: "cannot filter servers: test error",
errMessage: "no server",
},
"all servers filtered": {
servers: []models.Server{
{VPN: vpn.Wireguard},
{VPN: vpn.Wireguard},
},
serverSelection: settings.ServerSelection{
VPN: vpn.OpenVPN,
}.WithDefaults(providers.Mullvad),
errWrapped: ErrNoServerFound,
errMessage: "no server found: for VPN openvpn; protocol udp",
}, },
"server without IPs": { "server without IPs": {
servers: []models.Server{ filteredServers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true}, {VPN: vpn.OpenVPN, UDP: true},
{VPN: vpn.OpenVPN, UDP: true}, {VPN: vpn.OpenVPN, UDP: true},
}, },
@@ -58,7 +53,7 @@ func Test_GetConnection(t *testing.T) {
errMessage: "no connection to pick from", errMessage: "no connection to pick from",
}, },
"OpenVPN server with hostname": { "OpenVPN server with hostname": {
servers: []models.Server{ filteredServers: []models.Server{
{ {
VPN: vpn.OpenVPN, VPN: vpn.OpenVPN,
UDP: true, UDP: true,
@@ -79,7 +74,7 @@ func Test_GetConnection(t *testing.T) {
}, },
}, },
"OpenVPN server with x509": { "OpenVPN server with x509": {
servers: []models.Server{ filteredServers: []models.Server{
{ {
VPN: vpn.OpenVPN, VPN: vpn.OpenVPN,
UDP: true, UDP: true,
@@ -101,7 +96,7 @@ func Test_GetConnection(t *testing.T) {
}, },
}, },
"server with IPv4 and IPv6": { "server with IPv4 and IPv6": {
servers: []models.Server{ filteredServers: []models.Server{
{ {
VPN: vpn.OpenVPN, VPN: vpn.OpenVPN,
UDP: true, UDP: true,
@@ -128,7 +123,7 @@ func Test_GetConnection(t *testing.T) {
}, },
}, },
"mixed servers": { "mixed servers": {
servers: []models.Server{ filteredServers: []models.Server{
{ {
VPN: vpn.OpenVPN, VPN: vpn.OpenVPN,
UDP: true, UDP: true,
@@ -169,8 +164,14 @@ func Test_GetConnection(t *testing.T) {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t)
connection, err := GetConnection(testCase.servers, storage := common.NewMockStorage(ctrl)
storage.EXPECT().
FilterServers(testCase.provider, testCase.serverSelection).
Return(testCase.filteredServers, testCase.filterError)
connection, err := GetConnection(testCase.provider, storage,
testCase.serverSelection, testCase.defaults, testCase.serverSelection, testCase.defaults,
testCase.randSource) testCase.randSource)

View File

@@ -2,6 +2,7 @@ package vpnunlimited
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.VPNUnlimited,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited), NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited),
} }

View File

@@ -2,6 +2,7 @@ package vyprvpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Vyprvpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn),
} }

View File

@@ -2,6 +2,7 @@ package wevpn
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(1195, 1194, 0) //nolint:gomnd defaults := utils.NewConnectionDefaults(1195, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Wevpn,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -1,43 +1,68 @@
package wevpn package wevpn
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Provider_GetConnection(t *testing.T) { func Test_Provider_GetConnection(t *testing.T) {
t.Parallel() t.Parallel()
const provider = providers.Wevpn
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct { testCases := map[string]struct {
servers []models.Server filteredServers []models.Server
selection settings.ServerSelection storageErr error
connection models.Connection selection settings.ServerSelection
errWrapped error connection models.Connection
errMessage string errWrapped error
errMessage string
panicMessage string
}{ }{
"no server available": { "error": {
selection: settings.ServerSelection{ storageErr: errTest,
VPN: vpn.OpenVPN, errWrapped: errTest,
}.WithDefaults(providers.Wevpn), errMessage: "cannot filter servers: test error",
errWrapped: utils.ErrNoServer,
errMessage: "no server",
}, },
"no filter": { "default OpenVPN TCP port": {
servers: []models.Server{ filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true}, {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
}, },
selection: settings.ServerSelection{}.WithDefaults(providers.Wevpn), selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 1195,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1), IP: net.IPv4(1, 1, 1, 1),
@@ -45,38 +70,14 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"target IP": { "default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{ selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2), VPN: vpn.Wireguard,
}.WithDefaults(providers.Wevpn), }.WithDefaults(provider),
servers: []models.Server{ panicMessage: "no default Wireguard port is defined!",
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Wevpn),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Hostname: "b",
Protocol: constants.UDP,
},
}, },
} }
@@ -84,12 +85,23 @@ func Test_Provider_GetConnection(t *testing.T) {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0) randSource := rand.NewSource(0)
m := New(testCase.servers, randSource) provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection) if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
_, _ = provider.GetConnection(testCase.selection)
})
return
}
connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped) assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil { if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn),
} }

View File

@@ -2,6 +2,7 @@ package windscribe
import ( import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 1194) //nolint:gomnd defaults := utils.NewConnectionDefaults(443, 1194, 1194) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource) return utils.GetConnection(providers.Windscribe,
p.storage, selection, defaults, p.randSource)
} }

View File

@@ -1,41 +1,68 @@
package windscribe package windscribe
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Provider_GetConnection(t *testing.T) { func Test_Provider_GetConnection(t *testing.T) {
t.Parallel() t.Parallel()
const provider = providers.Windscribe
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct { testCases := map[string]struct {
servers []models.Server filteredServers []models.Server
selection settings.ServerSelection storageErr error
connection models.Connection selection settings.ServerSelection
errWrapped error connection models.Connection
errMessage string errWrapped error
errMessage string
panicMessage string
}{ }{
"no server available": { "error": {
selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe), storageErr: errTest,
errWrapped: utils.ErrNoServer, errWrapped: errTest,
errMessage: "no server", errMessage: "cannot filter servers: test error",
}, },
"no filter": { "default OpenVPN TCP port": {
servers: []models.Server{ filteredServers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}}, {IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
}, },
selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe), selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 443,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1), IP: net.IPv4(1, 1, 1, 1),
@@ -43,49 +70,41 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"target IP": { "default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{ selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2), VPN: vpn.Wireguard,
}.WithDefaults(providers.Windscribe), }.WithDefaults(provider),
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{ connection: models.Connection{
Type: vpn.OpenVPN, Type: vpn.Wireguard,
IP: net.IPv4(2, 2, 2, 2), IP: net.IPv4(1, 1, 1, 1),
Port: 1194, Port: 1194,
Protocol: constants.UDP, Protocol: constants.UDP,
}, },
}, },
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Windscribe),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Hostname: "b",
Protocol: constants.UDP,
},
},
} }
for name, testCase := range testCases { for name, testCase := range testCases {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0) randSource := rand.NewSource(0)
provider := New(testCase.servers, randSource) provider := New(storage, randSource)
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
_, _ = provider.GetConnection(testCase.selection)
})
return
}
connection, err := provider.GetConnection(testCase.selection) connection, err := provider.GetConnection(testCase.selection)
@@ -93,6 +112,7 @@ func Test_Provider_GetConnection(t *testing.T) {
if testCase.errWrapped != nil { if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
} }
assert.Equal(t, testCase.connection, connection) assert.Equal(t, testCase.connection, connection)
}) })
} }

View File

@@ -4,19 +4,19 @@ import (
"math/rand" "math/rand"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
type Provider struct { type Provider struct {
servers []models.Server storage common.Storage
randSource rand.Source randSource rand.Source
utils.NoPortForwarder utils.NoPortForwarder
} }
func New(servers []models.Server, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{ return &Provider{
servers: servers, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Windscribe), NoPortForwarder: utils.NewNoPortForwarding(providers.Windscribe),
} }

View File

@@ -0,0 +1,27 @@
package storage
import (
"github.com/qdm12/gluetun/internal/configuration/settings/validation"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
func (s *Storage) GetFilterChoices(provider string) models.FilterChoices {
if provider == providers.Custom {
return models.FilterChoices{}
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
servers := serversObject.Servers
return models.FilterChoices{
Countries: validation.ExtractCountries(servers),
Regions: validation.ExtractRegions(servers),
Cities: validation.ExtractCities(servers),
ISPs: validation.ExtractISPs(servers),
Names: validation.ExtractServerNames(servers),
Hostnames: validation.ExtractHostnames(servers),
}
}

32
internal/storage/copy.go Normal file
View File

@@ -0,0 +1,32 @@
package storage
import (
"net"
"github.com/qdm12/gluetun/internal/models"
)
func copyServer(server models.Server) (serverCopy models.Server) {
serverCopy = server
serverCopy.IPs = copyIPs(server.IPs)
return serverCopy
}
func copyIPs(toCopy []net.IP) (copied []net.IP) {
if toCopy == nil {
return nil
}
copied = make([]net.IP, len(toCopy))
for i := range toCopy {
copied[i] = copyIP(toCopy[i])
}
return copied
}
func copyIP(toCopy net.IP) (copied net.IP) {
copied = make(net.IP, len(toCopy))
copy(copied, toCopy)
return copied
}

View File

@@ -0,0 +1,75 @@
package storage
import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_copyServer(t *testing.T) {
t.Parallel()
server := models.Server{
Country: "a",
IPs: []net.IP{{1, 2, 3, 4}},
}
serverCopy := copyServer(server)
assert.Equal(t, server, serverCopy)
// Check for mutation
serverCopy.IPs[0][0] = 9
assert.NotEqual(t, server, serverCopy)
}
func Test_copyIPs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
toCopy []net.IP
copied []net.IP
}{
"nil": {},
"empty": {
toCopy: []net.IP{},
copied: []net.IP{},
},
"single IP": {
toCopy: []net.IP{{1, 1, 1, 1}},
copied: []net.IP{{1, 1, 1, 1}},
},
"two IPs": {
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Reserver leading 9 for copy modifications below
for _, ipToCopy := range testCase.toCopy {
require.NotEqual(t, 9, ipToCopy[0])
}
copied := copyIPs(testCase.toCopy)
assert.Equal(t, testCase.copied, copied)
if len(copied) > 0 {
original := testCase.toCopy[0][0]
testCase.toCopy[0][0] = 9
assert.NotEqual(t, 9, copied[0][0])
testCase.toCopy[0][0] = original
copied[0][0] = 9
assert.NotEqual(t, 9, testCase.toCopy[0][0])
}
})
}
}

143
internal/storage/filter.go Normal file
View File

@@ -0,0 +1,143 @@
package storage
import (
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
)
// FilterServers filter servers for the given provider and according
// to the given selection. The filtered servers are deep copied so they
// are safe for mutation by the caller.
func (s *Storage) FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error) {
if provider == providers.Custom {
return nil, nil
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
allServers := serversObject.Servers
if len(allServers) == 0 {
return nil, ErrNoServerFound
}
for _, server := range allServers {
if filterServer(server, selection) {
continue
}
server = copyServer(server)
servers = append(servers, server)
}
if len(servers) == 0 {
return nil, noServerFoundError(selection)
}
return servers, nil
}
func filterServer(server models.Server,
selection settings.ServerSelection) (filtered bool) {
// Note each condition is split to make sure
// we have full testing coverage.
if server.VPN != selection.VPN {
return true
}
if filterByProtocol(selection, server.TCP, server.UDP) {
return true
}
if *selection.MultiHopOnly && !server.MultiHop {
return true
}
if *selection.FreeOnly && !server.Free {
return true
}
if *selection.StreamOnly && !server.Stream {
return true
}
if *selection.OwnedOnly && !server.Owned {
return true
}
if filterByPossibilities(server.Country, selection.Countries) {
return true
}
if filterByPossibilities(server.Region, selection.Regions) {
return true
}
if filterByPossibilities(server.City, selection.Cities) {
return true
}
if filterByPossibilities(server.ISP, selection.ISPs) {
return true
}
if filterByPossibilitiesUint16(server.Number, selection.Numbers) {
return true
}
if filterByPossibilities(server.ServerName, selection.Names) {
return true
}
if filterByPossibilities(server.Hostname, selection.Hostnames) {
return true
}
// TODO filter port forward server for PIA
return false
}
func filterByPossibilities(value string, possibilities []string) (filtered bool) {
if len(possibilities) == 0 {
return false
}
for _, possibility := range possibilities {
if strings.EqualFold(value, possibility) {
return false
}
}
return true
}
// TODO merge with filterByPossibilities with generics in Go 1.18.
func filterByPossibilitiesUint16(value uint16, possibilities []uint16) (filtered bool) {
if len(possibilities) == 0 {
return false
}
for _, possibility := range possibilities {
if value == possibility {
return false
}
}
return true
}
func filterByProtocol(selection settings.ServerSelection,
serverTCP, serverUDP bool) (filtered bool) {
switch selection.VPN {
case vpn.Wireguard:
return !serverUDP
default: // OpenVPN
wantTCP := *selection.OpenVPN.TCP
wantUDP := !wantTCP
return (wantTCP && !serverTCP) || (wantUDP && !serverUDP)
}
}

View File

@@ -4,21 +4,20 @@ import (
"encoding/json" "encoding/json"
"os" "os"
"path/filepath" "path/filepath"
"github.com/qdm12/gluetun/internal/models"
) )
var _ Flusher = (*Storage)(nil) // FlushToFile flushes the merged servers data to the file
// specified by path, as indented JSON.
func (s *Storage) FlushToFile(path string) error {
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
type Flusher interface { return s.flushToFile(path)
FlushToFile(allServers *models.AllServers) error
} }
func (s *Storage) FlushToFile(allServers *models.AllServers) error { // flushToFile flushes the merged servers data to the file
return flushToFile(s.filepath, allServers) // specified by path, as indented JSON. It is not thread-safe.
} func (s *Storage) flushToFile(path string) error {
func flushToFile(path string, servers *models.AllServers) error {
dirPath := filepath.Dir(path) dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil { if err := os.MkdirAll(dirPath, 0644); err != nil {
return err return err
@@ -28,11 +27,15 @@ func flushToFile(path string, servers *models.AllServers) error {
if err != nil { if err != nil {
return err return err
} }
encoder := json.NewEncoder(file) encoder := json.NewEncoder(file)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
err = encoder.Encode(&s.mergedServers)
if err != nil {
_ = file.Close() _ = file.Close()
return err return err
} }
return file.Close() return file.Close()
} }

View File

@@ -1,4 +1,4 @@
package utils package storage
import ( import (
"errors" "errors"

View File

@@ -1,7 +1,118 @@
package storage package storage
import "github.com/qdm12/gluetun/internal/models" import (
"fmt"
"time"
func (s *Storage) GetServers() models.AllServers { "github.com/qdm12/gluetun/internal/constants/providers"
return s.mergedServers.GetCopy() "github.com/qdm12/gluetun/internal/models"
)
// SetServers sets the given servers for the given provider
// in the storage in-memory map and saves all the servers
// to file.
// Note the servers given are not copied so the caller must
// NOT MUTATE them after calling this method.
func (s *Storage) SetServers(provider string, servers []models.Server) (err error) {
if provider == providers.Custom {
return
}
s.mergedMutex.Lock()
defer s.mergedMutex.Unlock()
serversObject := s.getMergedServersObject(provider)
serversObject.Timestamp = time.Now().Unix()
serversObject.Servers = servers
s.mergedServers.ProviderToServers[provider] = serversObject
err = s.flushToFile(s.filepath)
if err != nil {
return fmt.Errorf("cannot save servers to file: %w", err)
}
return nil
}
// GetServerByName returns the server for the given provider
// and server name. It returns `ok` as false if the server is
// not found. The returned server is also deep copied so it is
// safe for mutation and/or thread safe use.
func (s *Storage) GetServerByName(provider, name string) (
server models.Server, ok bool) {
if provider == providers.Custom {
return server, false
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
for _, server := range serversObject.Servers {
if server.ServerName == name {
return copyServer(server), true
}
}
return server, false
}
// GetServersCount returns the number of servers for the provider given.
func (s *Storage) GetServersCount(provider string) (count int) {
if provider == providers.Custom {
return 0
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
return len(serversObject.Servers)
}
// FormatToMarkdown Markdown formats the servers for the provider given
// and returns the resulting string.
func (s *Storage) FormatToMarkdown(provider string) (formatted string) {
if provider == providers.Custom {
return ""
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
formatted = serversObject.ToMarkdown(provider)
return formatted
}
// GetServersCount returns the number of servers for the provider given.
func (s *Storage) ServersAreEqual(provider string, servers []models.Server) (equal bool) {
if provider == providers.Custom {
return true
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
existingServers := serversObject.Servers
if len(existingServers) != len(servers) {
return false
}
for i := range existingServers {
if !existingServers[i].Equal(servers[i]) {
return false
}
}
return true
}
func (s *Storage) getMergedServersObject(provider string) (serversObject models.Servers) {
serversObject, ok := s.mergedServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in in-memory servers map", provider))
}
return serversObject
} }

View File

@@ -2,11 +2,14 @@
package storage package storage
import ( import (
"sync"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
type Storage struct { type Storage struct {
mergedServers models.AllServers mergedServers models.AllServers
mergedMutex sync.RWMutex
// this is stored in memory to avoid re-parsing // this is stored in memory to avoid re-parsing
// the embedded JSON file on every call to the // the embedded JSON file on every call to the
// SyncServers method. // SyncServers method.

View File

@@ -29,6 +29,9 @@ func (s *Storage) syncServers() (err error) {
hardcodedCount := countServers(s.hardcodedServers) hardcodedCount := countServers(s.hardcodedServers)
countOnFile := countServers(serversOnFile) countOnFile := countServers(serversOnFile)
s.mergedMutex.Lock()
defer s.mergedMutex.Unlock()
if countOnFile == 0 { if countOnFile == 0 {
s.logger.Info(fmt.Sprintf( s.logger.Info(fmt.Sprintf(
"creating %s with %d hardcoded servers", "creating %s with %d hardcoded servers",
@@ -47,7 +50,8 @@ func (s *Storage) syncServers() (err error) {
return nil return nil
} }
if err := flushToFile(s.filepath, &s.mergedServers); err != nil { err = s.flushToFile(s.filepath)
if err != nil {
return fmt.Errorf("cannot write servers to file: %w", err) return fmt.Errorf("cannot write servers to file: %w", err)
} }
return nil return nil

View File

@@ -9,7 +9,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
) )
@@ -24,16 +23,14 @@ type Looper interface {
} }
type Updater interface { type Updater interface {
UpdateServers(ctx context.Context) (allServers models.AllServers, err error) UpdateServers(ctx context.Context) (err error)
} }
type looper struct { type looper struct {
state state state state
// Objects // Objects
updater Updater updater Updater
flusher storage.Flusher logger Logger
setAllServers func(allServers models.AllServers)
logger Logger
// Internal channels and locks // Internal channels and locks
loopLock sync.Mutex loopLock sync.Mutex
start chan struct{} start chan struct{}
@@ -49,32 +46,35 @@ type looper struct {
const defaultBackoffTime = 5 * time.Second const defaultBackoffTime = 5 * time.Second
type Storage interface {
SetServers(provider string, servers []models.Server) (err error)
GetServersCount(provider string) (count int)
ServersAreEqual(provider string, servers []models.Server) (equal bool)
}
type Logger interface { type Logger interface {
Info(s string) Info(s string)
Warn(s string) Warn(s string)
Error(s string) Error(s string)
} }
func NewLooper(settings settings.Updater, currentServers models.AllServers, func NewLooper(settings settings.Updater, storage Storage,
flusher storage.Flusher, setAllServers func(allServers models.AllServers),
client *http.Client, logger Logger) Looper { client *http.Client, logger Logger) Looper {
return &looper{ return &looper{
state: state{ state: state{
status: constants.Stopped, status: constants.Stopped,
settings: settings, settings: settings,
}, },
updater: updater.New(settings, client, currentServers, logger), updater: updater.New(settings, client, storage, logger),
flusher: flusher, logger: logger,
setAllServers: setAllServers, start: make(chan struct{}),
logger: logger, running: make(chan models.LoopStatus),
start: make(chan struct{}), stop: make(chan struct{}),
running: make(chan models.LoopStatus), stopped: make(chan struct{}),
stop: make(chan struct{}), updateTicker: make(chan struct{}),
stopped: make(chan struct{}), timeNow: time.Now,
updateTicker: make(chan struct{}), timeSince: time.Since,
timeNow: time.Now, backoffTime: defaultBackoffTime,
timeSince: time.Since,
backoffTime: defaultBackoffTime,
} }
} }
@@ -106,20 +106,19 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
for ctx.Err() == nil { for ctx.Err() == nil {
updateCtx, updateCancel := context.WithCancel(ctx) updateCtx, updateCancel := context.WithCancel(ctx)
serversCh := make(chan models.AllServers)
errorCh := make(chan error) errorCh := make(chan error)
runWg := &sync.WaitGroup{} runWg := &sync.WaitGroup{}
runWg.Add(1) runWg.Add(1)
go func() { go func() {
defer runWg.Done() defer runWg.Done()
servers, err := l.updater.UpdateServers(updateCtx) err := l.updater.UpdateServers(updateCtx)
if err != nil { if err != nil {
if updateCtx.Err() == nil { if updateCtx.Err() == nil {
errorCh <- err errorCh <- err
} }
return return
} }
serversCh <- servers l.state.setStatusWithLock(constants.Completed)
}() }()
if !crashed { if !crashed {
@@ -148,16 +147,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
updateCancel() updateCancel()
runWg.Wait() runWg.Wait()
l.stopped <- struct{}{} l.stopped <- struct{}{}
case servers := <-serversCh:
l.setAllServers(servers)
if err := l.flusher.FlushToFile(&servers); err != nil {
l.logger.Error(err.Error())
}
runWg.Wait()
l.state.setStatusWithLock(constants.Completed)
l.logger.Info("Updated servers information")
case err := <-errorCh: case err := <-errorCh:
close(serversCh)
runWg.Wait() runWg.Wait()
l.state.setStatusWithLock(constants.Crashed) l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)

View File

@@ -3,8 +3,6 @@ package updater
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"time"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
@@ -31,18 +29,25 @@ import (
) )
func (u *Updater) updateProvider(ctx context.Context, provider string) (err error) { func (u *Updater) updateProvider(ctx context.Context, provider string) (err error) {
existingServers := u.getProviderServers(provider) existingServersCount := u.storage.GetServersCount(provider)
minServers := getMinServers(existingServers) minServers := getMinServers(existingServersCount)
servers, err := u.getServers(ctx, provider, minServers) servers, err := u.getServers(ctx, provider, minServers)
if err != nil { if err != nil {
return err return fmt.Errorf("cannot get servers: %w", err)
} }
if reflect.DeepEqual(existingServers, servers) { if u.storage.ServersAreEqual(provider, servers) {
return nil return nil
} }
u.patchProvider(provider, servers) // Note the servers variable must NOT BE MUTATED after this call,
// since the implementation does not deep copy the servers.
// TODO set in storage in provider updater directly, server by server,
// to avoid accumulating server data in memory.
err = u.storage.SetServers(provider, servers)
if err != nil {
return fmt.Errorf("cannot set servers to storage: %w", err)
}
return nil return nil
} }
@@ -101,25 +106,7 @@ func (u *Updater) getServers(ctx context.Context, provider string,
return providerUpdater.GetServers(ctx, minServers) return providerUpdater.GetServers(ctx, minServers)
} }
func (u *Updater) getProviderServers(provider string) (servers []models.Server) { func getMinServers(existingServersCount int) (minServers int) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
return providerServers.Servers
}
func getMinServers(servers []models.Server) (minServers int) {
const minRatio = 0.8 const minRatio = 0.8
return int(minRatio * float64(len(servers))) return int(minRatio * float64(existingServersCount))
}
func (u *Updater) patchProvider(provider string, servers []models.Server) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
providerServers.Timestamp = time.Now().Unix()
providerServers.Servers = servers
u.servers.ProviderToServers[provider] = providerServers
} }

View File

@@ -19,7 +19,7 @@ type Updater struct {
options settings.Updater options settings.Updater
// state // state
servers models.AllServers storage Storage
// Functions for tests // Functions for tests
logger Logger logger Logger
@@ -29,6 +29,12 @@ type Updater struct {
unzipper unzip.Unzipper unzipper unzip.Unzipper
} }
type Storage interface {
SetServers(provider string, servers []models.Server) (err error)
GetServersCount(provider string) (count int)
ServersAreEqual(provider string, servers []models.Server) (equal bool)
}
type Logger interface { type Logger interface {
Info(s string) Info(s string)
Warn(s string) Warn(s string)
@@ -36,20 +42,20 @@ type Logger interface {
} }
func New(settings settings.Updater, httpClient *http.Client, func New(settings settings.Updater, httpClient *http.Client,
currentServers models.AllServers, logger Logger) *Updater { storage Storage, logger Logger) *Updater {
unzipper := unzip.New(httpClient) unzipper := unzip.New(httpClient)
return &Updater{ return &Updater{
options: settings,
storage: storage,
logger: logger, logger: logger,
timeNow: time.Now, timeNow: time.Now,
presolver: resolver.NewParallelResolver(settings.DNSAddress.String()), presolver: resolver.NewParallelResolver(settings.DNSAddress.String()),
client: httpClient, client: httpClient,
unzipper: unzipper, unzipper: unzipper,
options: settings,
servers: currentServers,
} }
} }
func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) { func (u *Updater) UpdateServers(ctx context.Context) (err error) {
caser := cases.Title(language.English) caser := cases.Title(language.English)
for _, provider := range u.options.Providers { for _, provider := range u.options.Providers {
u.logger.Info("updating " + caser.String(provider) + " servers...") u.logger.Info("updating " + caser.String(provider) + " servers...")
@@ -62,17 +68,17 @@ func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServe
// return the only error for the single provider. // return the only error for the single provider.
if len(u.options.Providers) == 1 { if len(u.options.Providers) == 1 {
return allServers, err return err
} }
// stop updating the next providers if context is canceled. // stop updating the next providers if context is canceled.
if ctxErr := ctx.Err(); ctxErr != nil { if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr return ctxErr
} }
// Log the error and continue updating the next provider. // Log the error and continue updating the next provider.
u.logger.Error(err.Error()) u.logger.Error(err.Error())
} }
return u.servers, nil return nil
} }

View File

@@ -27,12 +27,12 @@ type Looper interface {
loopstate.Getter loopstate.Getter
loopstate.Applier loopstate.Applier
SettingsGetSetter SettingsGetSetter
ServersGetterSetter
} }
type Loop struct { type Loop struct {
statusManager loopstate.Manager statusManager loopstate.Manager
state state.Manager state state.Manager
storage Storage
// Fixed parameters // Fixed parameters
buildInfo models.BuildInformation buildInfo models.BuildInformation
versionInfo bool versionInfo bool
@@ -64,12 +64,17 @@ type firewallConfigurer interface {
firewall.PortAllower firewall.PortAllower
} }
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}
const ( const (
defaultBackoffTime = 15 * time.Second defaultBackoffTime = 15 * time.Second
) )
func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16, func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16,
allServers models.AllServers, openvpnConf openvpn.Interface, storage Storage, openvpnConf openvpn.Interface,
netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter, netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter,
portForward portforward.StartStopper, starter command.Starter, portForward portforward.StartStopper, starter command.Starter,
publicip publicip.Looper, dnsLooper dns.Looper, publicip publicip.Looper, dnsLooper dns.Looper,
@@ -81,11 +86,12 @@ func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16,
stopped := make(chan struct{}) stopped := make(chan struct{})
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, vpnSettings, allServers) state := state.New(statusManager, vpnSettings)
return &Loop{ return &Loop{
statusManager: statusManager, statusManager: statusManager,
state: state, state: state,
storage: storage,
buildInfo: buildInfo, buildInfo: buildInfo,
versionInfo: versionInfo, versionInfo: versionInfo,
vpnInputPorts: vpnInputPorts, vpnInputPorts: vpnInputPorts,

View File

@@ -28,9 +28,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
} }
for ctx.Err() == nil { for ctx.Err() == nil {
settings, allServers := l.state.GetSettingsAndServers() settings := l.state.GetSettings()
providerConf := provider.New(*settings.Provider.Name, allServers, time.Now) providerConf := provider.New(*settings.Provider.Name, l.storage, time.Now)
portForwarding := *settings.Provider.PortForwarding.Enabled portForwarding := *settings.Provider.PortForwarding.Enabled
var vpnRunner vpnRunner var vpnRunner vpnRunner

View File

@@ -1,16 +0,0 @@
package vpn
import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/vpn/state"
)
type ServersGetterSetter = state.ServersGetterSetter
func (l *Loop) GetServers() (servers models.AllServers) {
return l.state.GetServers()
}
func (l *Loop) SetServers(servers models.AllServers) {
l.state.SetServers(servers)
}

View File

@@ -1,20 +0,0 @@
package state
import "github.com/qdm12/gluetun/internal/models"
type ServersGetterSetter interface {
GetServers() (servers models.AllServers)
SetServers(servers models.AllServers)
}
func (s *State) GetServers() (servers models.AllServers) {
s.allServersMu.RLock()
defer s.allServersMu.RUnlock()
return s.allServers
}
func (s *State) SetServers(servers models.AllServers) {
s.allServersMu.Lock()
defer s.allServersMu.Unlock()
s.allServers = servers
}

View File

@@ -5,23 +5,18 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
) )
var _ Manager = (*State)(nil) var _ Manager = (*State)(nil)
type Manager interface { type Manager interface {
SettingsGetSetter SettingsGetSetter
ServersGetterSetter
GetSettingsAndServers() (vpn settings.VPN, allServers models.AllServers)
} }
func New(statusApplier loopstate.Applier, func New(statusApplier loopstate.Applier, vpn settings.VPN) *State {
vpn settings.VPN, allServers models.AllServers) *State {
return &State{ return &State{
statusApplier: statusApplier, statusApplier: statusApplier,
vpn: vpn, vpn: vpn,
allServers: allServers,
} }
} }
@@ -30,18 +25,4 @@ type State struct {
vpn settings.VPN vpn settings.VPN
settingsMu sync.RWMutex settingsMu sync.RWMutex
allServers models.AllServers
allServersMu sync.RWMutex
}
func (s *State) GetSettingsAndServers() (vpn settings.VPN,
allServers models.AllServers) {
s.settingsMu.RLock()
s.allServersMu.RLock()
vpn = s.vpn
allServers = s.allServers
s.settingsMu.RUnlock()
s.allServersMu.RUnlock()
return vpn, allServers
} }