chore(all): memory and thread safe storage
- settings: get filter choices from storage for settings validation - updater: update servers to the storage - storage: minimal deep copying and data duplication - storage: add merged servers mutex for thread safety - connection: filter servers in storage - formatter: format servers to Markdown in storage - PIA: get server by name from storage directly - Updater: get servers count from storage directly - Updater: equality check done in storage, fix #882
This commit is contained in:
@@ -215,9 +215,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return err
|
||||
}
|
||||
|
||||
allServers := storage.GetServers()
|
||||
|
||||
err = allSettings.Validate(allServers)
|
||||
err = allSettings.Validate(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -378,7 +376,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
vpnLogger := logger.New(log.SetComponent("vpn"))
|
||||
vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts,
|
||||
allServers, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
|
||||
storage, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
|
||||
cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
|
||||
buildInfo, *allSettings.Version.Enabled)
|
||||
vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler(
|
||||
@@ -386,8 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
go vpnLooper.Run(vpnCtx, vpnDone)
|
||||
|
||||
updaterLooper := updater.NewLooper(allSettings.Updater,
|
||||
allServers, storage, vpnLooper.SetServers, httpClient,
|
||||
logger.New(log.SetComponent("updater")))
|
||||
storage, httpClient, logger.New(log.SetComponent("updater")))
|
||||
updaterHandler, updaterCtx, updaterDone := goshutdown.NewGoRoutineHandler(
|
||||
"updater", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
@@ -80,9 +79,8 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create servers storage: %w", err)
|
||||
}
|
||||
currentServers := storage.GetServers()
|
||||
|
||||
formatted := formatServers(currentServers, providerToFormat)
|
||||
formatted := storage.FormatToMarkdown(providerToFormat)
|
||||
|
||||
output = filepath.Clean(output)
|
||||
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
|
||||
@@ -103,11 +101,3 @@ func (c *CLI) FormatServers(args []string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatServers(allServers models.AllServers, provider string) (formatted string) {
|
||||
servers, ok := allServers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
|
||||
}
|
||||
return servers.ToMarkdown(provider)
|
||||
}
|
||||
|
||||
@@ -25,18 +25,17 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) e
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allServers := storage.GetServers()
|
||||
|
||||
allSettings, err := source.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = allSettings.Validate(allServers); err != nil {
|
||||
if err = allSettings.Validate(storage); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
providerConf := provider.New(*allSettings.VPN.Provider.Name, allServers, time.Now)
|
||||
providerConf := provider.New(*allSettings.VPN.Provider.Name, storage, time.Now)
|
||||
connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,20 +2,17 @@ package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
)
|
||||
@@ -83,41 +80,19 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create servers storage: %w", err)
|
||||
}
|
||||
currentServers := storage.GetServers()
|
||||
|
||||
updater := updater.New(options, httpClient, currentServers, logger)
|
||||
allServers, err := updater.UpdateServers(ctx)
|
||||
updater := updater.New(options, httpClient, storage, logger)
|
||||
err = updater.UpdateServers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot update server information: %w", err)
|
||||
}
|
||||
|
||||
if endUserMode {
|
||||
if err := storage.FlushToFile(&allServers); err != nil {
|
||||
return fmt.Errorf("cannot write updated information to file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if maintainerMode {
|
||||
if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil {
|
||||
return fmt.Errorf("cannot write updated information to file: %w", err)
|
||||
err := storage.FlushToFile(c.repoServersPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write servers data to embedded JSON file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeToEmbeddedJSON(repoServersPath string,
|
||||
allServers *models.AllServers) error {
|
||||
const perms = 0600
|
||||
f, err := os.OpenFile(repoServersPath,
|
||||
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder.Encode(allServers)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
@@ -23,7 +22,7 @@ type Provider struct {
|
||||
}
|
||||
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
func (p *Provider) validate(vpnType string, allServers models.AllServers) (err error) {
|
||||
func (p *Provider) validate(vpnType string, storage Storage) (err error) {
|
||||
// Validate Name
|
||||
var validNames []string
|
||||
if vpnType == vpn.OpenVPN {
|
||||
@@ -42,7 +41,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
|
||||
ErrVPNProviderNameNotValid, *p.Name, helpers.ChoicesOrString(validNames))
|
||||
}
|
||||
|
||||
err = p.ServerSelection.validate(*p.Name, allServers)
|
||||
err = p.ServerSelection.validate(*p.Name, storage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server selection: %w", err)
|
||||
}
|
||||
|
||||
@@ -68,21 +68,19 @@ var (
|
||||
)
|
||||
|
||||
func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||
allServers models.AllServers) (err error) {
|
||||
storage Storage) (err error) {
|
||||
switch ss.VPN {
|
||||
case vpn.OpenVPN, vpn.Wireguard:
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
|
||||
}
|
||||
|
||||
countryChoices, regionChoices, cityChoices,
|
||||
ispChoices, nameChoices, hostnameChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, allServers)
|
||||
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage)
|
||||
if err != nil {
|
||||
return err // already wrapped error
|
||||
}
|
||||
|
||||
err = validateServerFilters(*ss, countryChoices, regionChoices, cityChoices,
|
||||
ispChoices, nameChoices, hostnameChoices)
|
||||
err = validateServerFilters(*ss, filterChoices)
|
||||
if err != nil {
|
||||
if errors.Is(err, helpers.ErrNoChoice) {
|
||||
return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err)
|
||||
@@ -135,63 +133,48 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection,
|
||||
allServers models.AllServers) (
|
||||
countryChoices, regionChoices, cityChoices,
|
||||
ispChoices, nameChoices, hostnameChoices []string,
|
||||
func getLocationFilterChoices(vpnServiceProvider string,
|
||||
ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices,
|
||||
err error) {
|
||||
providerServers, ok := allServers.ProviderToServers[vpnServiceProvider]
|
||||
if !ok && vpnServiceProvider != providers.Custom {
|
||||
panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider))
|
||||
}
|
||||
servers := providerServers.Servers
|
||||
countryChoices = validation.ExtractCountries(servers)
|
||||
regionChoices = validation.ExtractRegions(servers)
|
||||
cityChoices = validation.ExtractCities(servers)
|
||||
ispChoices = validation.ExtractISPs(servers)
|
||||
nameChoices = validation.ExtractServerNames(servers)
|
||||
hostnameChoices = validation.ExtractHostnames(servers)
|
||||
filterChoices = storage.GetFilterChoices(vpnServiceProvider)
|
||||
|
||||
if vpnServiceProvider == providers.Surfshark {
|
||||
// // Retro compatibility
|
||||
// TODO v4 remove
|
||||
regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...)
|
||||
if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
|
||||
filterChoices.Regions = append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...)
|
||||
if err := helpers.AreAllOneOf(ss.Regions, filterChoices.Regions); err != nil {
|
||||
return models.FilterChoices{}, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
|
||||
}
|
||||
*ss = surfsharkRetroRegion(*ss)
|
||||
}
|
||||
|
||||
return countryChoices, regionChoices, cityChoices,
|
||||
ispChoices, nameChoices, hostnameChoices, nil
|
||||
return filterChoices, nil
|
||||
}
|
||||
|
||||
// validateServerFilters validates filters against the choices given as arguments.
|
||||
// Set an argument to nil to pass the check for a particular filter.
|
||||
func validateServerFilters(settings ServerSelection,
|
||||
countryChoices, regionChoices, cityChoices, ispChoices,
|
||||
nameChoices, hostnameChoices []string) (err error) {
|
||||
if err := helpers.AreAllOneOf(settings.Countries, countryChoices); err != nil {
|
||||
func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices) (err error) {
|
||||
if err := helpers.AreAllOneOf(settings.Countries, filterChoices.Countries); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrCountryNotValid, err)
|
||||
}
|
||||
|
||||
if err := helpers.AreAllOneOf(settings.Regions, regionChoices); err != nil {
|
||||
if err := helpers.AreAllOneOf(settings.Regions, filterChoices.Regions); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrRegionNotValid, err)
|
||||
}
|
||||
|
||||
if err := helpers.AreAllOneOf(settings.Cities, cityChoices); err != nil {
|
||||
if err := helpers.AreAllOneOf(settings.Cities, filterChoices.Cities); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrCityNotValid, err)
|
||||
}
|
||||
|
||||
if err := helpers.AreAllOneOf(settings.ISPs, ispChoices); err != nil {
|
||||
if err := helpers.AreAllOneOf(settings.ISPs, filterChoices.ISPs); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrISPNotValid, err)
|
||||
}
|
||||
|
||||
if err := helpers.AreAllOneOf(settings.Hostnames, hostnameChoices); err != nil {
|
||||
if err := helpers.AreAllOneOf(settings.Hostnames, filterChoices.Hostnames); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrHostnameNotValid, err)
|
||||
}
|
||||
|
||||
if err := helpers.AreAllOneOf(settings.Names, nameChoices); err != nil {
|
||||
if err := helpers.AreAllOneOf(settings.Names, filterChoices.Names); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrNameNotValid, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,10 +24,14 @@ type Settings struct {
|
||||
Pprof pprof.Settings
|
||||
}
|
||||
|
||||
type Storage interface {
|
||||
GetFilterChoices(provider string) models.FilterChoices
|
||||
}
|
||||
|
||||
// Validate validates all the settings and returns an error
|
||||
// if one of them is not valid.
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
func (s *Settings) Validate(allServers models.AllServers) (err error) {
|
||||
func (s *Settings) Validate(storage Storage) (err error) {
|
||||
nameToValidation := map[string]func() error{
|
||||
"control server": s.ControlServer.validate,
|
||||
"dns": s.DNS.validate,
|
||||
@@ -42,7 +46,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) {
|
||||
"version": s.Version.validate,
|
||||
// Pprof validation done in pprof constructor
|
||||
"VPN": func() error {
|
||||
return s.VPN.validate(allServers)
|
||||
return s.VPN.validate(storage)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -91,7 +95,7 @@ func (s *Settings) MergeWith(other Settings) {
|
||||
}
|
||||
|
||||
func (s *Settings) OverrideWith(other Settings,
|
||||
allServers models.AllServers) (err error) {
|
||||
storage Storage) (err error) {
|
||||
patchedSettings := s.copy()
|
||||
patchedSettings.ControlServer.overrideWith(other.ControlServer)
|
||||
patchedSettings.DNS.overrideWith(other.DNS)
|
||||
@@ -106,7 +110,7 @@ func (s *Settings) OverrideWith(other Settings,
|
||||
patchedSettings.Version.overrideWith(other.Version)
|
||||
patchedSettings.VPN.overrideWith(other.VPN)
|
||||
patchedSettings.Pprof.MergeWith(other.Pprof)
|
||||
err = patchedSettings.Validate(allServers)
|
||||
err = patchedSettings.Validate(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -35,17 +35,18 @@ func (u Updater) Validate() (err error) {
|
||||
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
|
||||
}
|
||||
|
||||
for i, provider := range u.Providers {
|
||||
validProviders := providers.All()
|
||||
for _, provider := range u.Providers {
|
||||
valid := false
|
||||
for _, validProvider := range providers.All() {
|
||||
for _, validProvider := range validProviders {
|
||||
if provider == validProvider {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("%w: %s at index %d",
|
||||
ErrVPNProviderNameNotValid, provider, i)
|
||||
return fmt.Errorf("%w: %q can only be one of %s",
|
||||
ErrVPNProviderNameNotValid, provider, helpers.ChoicesOrString(validProviders))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
@@ -21,7 +20,7 @@ type VPN struct {
|
||||
}
|
||||
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
func (v *VPN) validate(allServers models.AllServers) (err error) {
|
||||
func (v *VPN) validate(storage Storage) (err error) {
|
||||
// Validate Type
|
||||
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
|
||||
if !helpers.IsOneOf(v.Type, validVPNTypes...) {
|
||||
@@ -29,7 +28,7 @@ func (v *VPN) validate(allServers models.AllServers) (err error) {
|
||||
ErrVPNTypeNotValid, v.Type, strings.Join(validVPNTypes, ", "))
|
||||
}
|
||||
|
||||
err = v.Provider.validate(v.Type, allServers)
|
||||
err = v.Provider.validate(v.Type, storage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("provider settings: %w", err)
|
||||
}
|
||||
|
||||
10
internal/models/filters.go
Normal file
10
internal/models/filters.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package models
|
||||
|
||||
type FilterChoices struct {
|
||||
Countries []string
|
||||
Regions []string
|
||||
Cities []string
|
||||
ISPs []string
|
||||
Names []string
|
||||
Hostnames []string
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func (a AllServers) GetCopy() (allServersCopy AllServers) {
|
||||
allServersCopy.Version = a.Version
|
||||
allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
|
||||
for provider, servers := range a.ProviderToServers {
|
||||
allServersCopy.ProviderToServers[provider] = Servers{
|
||||
Version: servers.Version,
|
||||
Timestamp: servers.Timestamp,
|
||||
Servers: copyServers(servers.Servers),
|
||||
}
|
||||
}
|
||||
return allServersCopy
|
||||
}
|
||||
|
||||
func copyServers(servers []Server) (serversCopy []Server) {
|
||||
if servers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
serversCopy = make([]Server, len(servers))
|
||||
for i, server := range servers {
|
||||
serversCopy[i] = server
|
||||
serversCopy[i].IPs = copyIPs(server.IPs)
|
||||
}
|
||||
|
||||
return serversCopy
|
||||
}
|
||||
|
||||
func copyIPs(toCopy []net.IP) (copied []net.IP) {
|
||||
if toCopy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copied = make([]net.IP, len(toCopy))
|
||||
for i := range toCopy {
|
||||
copied[i] = copyIP(toCopy[i])
|
||||
}
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func copyIP(toCopy net.IP) (copied net.IP) {
|
||||
copied = make(net.IP, len(toCopy))
|
||||
copy(copied, toCopy)
|
||||
return copied
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_AllServers_GetCopy(t *testing.T) {
|
||||
allServers := AllServers{
|
||||
Version: 1,
|
||||
ProviderToServers: map[string]Servers{
|
||||
providers.Cyberghost: {
|
||||
Version: 2,
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Expressvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Fastestvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.HideMyAss: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Ipvanish: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Ivpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Mullvad: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Nordvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Perfectprivacy: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Privado: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.PrivateInternetAccess: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Privatevpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Protonvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Purevpn: {
|
||||
Version: 1,
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Surfshark: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Torguard: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.VPNUnlimited: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Vyprvpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Wevpn: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
providers.Windscribe: {
|
||||
Servers: []Server{{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
servers := allServers.GetCopy()
|
||||
|
||||
assert.Equal(t, allServers, servers)
|
||||
}
|
||||
|
||||
func Test_copyIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
toCopy []net.IP
|
||||
copied []net.IP
|
||||
}{
|
||||
"nil": {},
|
||||
"empty": {
|
||||
toCopy: []net.IP{},
|
||||
copied: []net.IP{},
|
||||
},
|
||||
"single IP": {
|
||||
toCopy: []net.IP{{1, 1, 1, 1}},
|
||||
copied: []net.IP{{1, 1, 1, 1}},
|
||||
},
|
||||
"two IPs": {
|
||||
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
|
||||
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Reserver leading 9 for copy modifications below
|
||||
for _, ipToCopy := range testCase.toCopy {
|
||||
require.NotEqual(t, 9, ipToCopy[0])
|
||||
}
|
||||
|
||||
copied := copyIPs(testCase.toCopy)
|
||||
|
||||
assert.Equal(t, testCase.copied, copied)
|
||||
|
||||
if len(copied) > 0 {
|
||||
original := testCase.toCopy[0][0]
|
||||
testCase.toCopy[0][0] = 9
|
||||
assert.NotEqual(t, 9, copied[0][0])
|
||||
testCase.toCopy[0][0] = original
|
||||
|
||||
copied[0][0] = 9
|
||||
assert.NotEqual(t, 9, testCase.toCopy[0][0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -26,3 +27,28 @@ type Server struct {
|
||||
PortForward bool `json:"port_forward,omitempty"`
|
||||
IPs []net.IP `json:"ips,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) Equal(other Server) (equal bool) {
|
||||
if !ipsAreEqual(s.IPs, other.IPs) {
|
||||
return false
|
||||
}
|
||||
|
||||
serverCopy := *s
|
||||
serverCopy.IPs = nil
|
||||
other.IPs = nil
|
||||
return reflect.DeepEqual(serverCopy, other)
|
||||
}
|
||||
|
||||
func ipsAreEqual(a, b []net.IP) (equal bool) {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range a {
|
||||
if !a[i].Equal(b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
120
internal/models/server_test.go
Normal file
120
internal/models/server_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Server_Equal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
a *Server
|
||||
b Server
|
||||
equal bool
|
||||
}{
|
||||
"same IPs": {
|
||||
a: &Server{
|
||||
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
b: Server{
|
||||
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
equal: true,
|
||||
},
|
||||
"same IP strings": {
|
||||
a: &Server{
|
||||
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
b: Server{
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
},
|
||||
equal: true,
|
||||
},
|
||||
"different IPs": {
|
||||
a: &Server{
|
||||
IPs: []net.IP{{1, 2, 3, 4}, {2, 3, 4, 5}},
|
||||
},
|
||||
b: Server{
|
||||
IPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 4}},
|
||||
},
|
||||
},
|
||||
"all fields equal": {
|
||||
a: &Server{
|
||||
VPN: "vpn",
|
||||
Country: "country",
|
||||
Region: "region",
|
||||
City: "city",
|
||||
ISP: "isp",
|
||||
Owned: true,
|
||||
Number: 1,
|
||||
ServerName: "server_name",
|
||||
Hostname: "hostname",
|
||||
TCP: true,
|
||||
UDP: true,
|
||||
OvpnX509: "x509",
|
||||
RetroLoc: "retroloc",
|
||||
MultiHop: true,
|
||||
WgPubKey: "wgpubkey",
|
||||
Free: true,
|
||||
Stream: true,
|
||||
PortForward: true,
|
||||
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
b: Server{
|
||||
VPN: "vpn",
|
||||
Country: "country",
|
||||
Region: "region",
|
||||
City: "city",
|
||||
ISP: "isp",
|
||||
Owned: true,
|
||||
Number: 1,
|
||||
ServerName: "server_name",
|
||||
Hostname: "hostname",
|
||||
TCP: true,
|
||||
UDP: true,
|
||||
OvpnX509: "x509",
|
||||
RetroLoc: "retroloc",
|
||||
MultiHop: true,
|
||||
WgPubKey: "wgpubkey",
|
||||
Free: true,
|
||||
Stream: true,
|
||||
PortForward: true,
|
||||
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
equal: true,
|
||||
},
|
||||
"different field": {
|
||||
a: &Server{
|
||||
VPN: "vpn",
|
||||
},
|
||||
b: Server{
|
||||
VPN: "other vpn",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipsOfANotNil := testCase.a.IPs != nil
|
||||
ipsOfBNotNil := testCase.b.IPs != nil
|
||||
|
||||
equal := testCase.a.Equal(testCase.b)
|
||||
|
||||
assert.Equal(t, testCase.equal, equal)
|
||||
|
||||
// Ensure IPs field is not modified
|
||||
if ipsOfANotNil {
|
||||
assert.NotNil(t, testCase.a)
|
||||
}
|
||||
if ipsOfBNotNil {
|
||||
assert.NotNil(t, testCase.b)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,18 +15,6 @@ type AllServers struct {
|
||||
ProviderToServers map[string]Servers
|
||||
}
|
||||
|
||||
func (a *AllServers) ServersSlice(provider string) []Server {
|
||||
if provider == providers.Custom {
|
||||
return nil
|
||||
}
|
||||
|
||||
servers, ok := a.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in all servers", provider))
|
||||
}
|
||||
return copyServers(servers.Servers)
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*AllServers)(nil)
|
||||
|
||||
// MarshalJSON marshals all servers to JSON.
|
||||
|
||||
66
internal/provider/common/mocks.go
Normal file
66
internal/provider/common/mocks.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/provider/common (interfaces: Storage)
|
||||
|
||||
// Package common is a generated GoMock package.
|
||||
package common
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
settings "github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
models "github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
// MockStorage is a mock of Storage interface.
|
||||
type MockStorage struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStorageMockRecorder
|
||||
}
|
||||
|
||||
// MockStorageMockRecorder is the mock recorder for MockStorage.
|
||||
type MockStorageMockRecorder struct {
|
||||
mock *MockStorage
|
||||
}
|
||||
|
||||
// NewMockStorage creates a new mock instance.
|
||||
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
|
||||
mock := &MockStorage{ctrl: ctrl}
|
||||
mock.recorder = &MockStorageMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// FilterServers mocks base method.
|
||||
func (m *MockStorage) FilterServers(arg0 string, arg1 settings.ServerSelection) ([]models.Server, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterServers", arg0, arg1)
|
||||
ret0, _ := ret[0].([]models.Server)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FilterServers indicates an expected call of FilterServers.
|
||||
func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetServerByName mocks base method.
|
||||
func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1)
|
||||
ret0, _ := ret[0].(models.Server)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServerByName indicates an expected call of GetServerByName.
|
||||
func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1)
|
||||
}
|
||||
5
internal/provider/common/mocks_generate_test.go
Normal file
5
internal/provider/common/mocks_generate_test.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package common
|
||||
|
||||
// Exceptionally, the storage mock is exported since it is used by all
|
||||
// provider subpackages tests, and it reduces test code duplication a lot.
|
||||
//go:generate mockgen -destination=mocks.go -package $GOPACKAGE . Storage
|
||||
12
internal/provider/common/storage.go
Normal file
12
internal/provider/common/storage.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error)
|
||||
GetServerByName(provider, name string) (server models.Server, ok bool)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package cyberghost
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Cyberghost,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package expressvpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(0, 1195, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Expressvpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -1,41 +1,63 @@
|
||||
package expressvpn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Provider_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const provider = providers.Expressvpn
|
||||
|
||||
errTest := errors.New("test error")
|
||||
boolPtr := func(b bool) *bool { return &b }
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
filteredServers []models.Server
|
||||
storageErr error
|
||||
selection settings.ServerSelection
|
||||
connection models.Connection
|
||||
errWrapped error
|
||||
errMessage string
|
||||
panicMessage string
|
||||
}{
|
||||
"no server": {
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn),
|
||||
errWrapped: utils.ErrNoServer,
|
||||
errMessage: "no server",
|
||||
"error": {
|
||||
storageErr: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
"default OpenVPN TCP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn),
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
panicMessage: "no default OpenVPN TCP port is defined!",
|
||||
},
|
||||
"default OpenVPN UDP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
@@ -43,38 +65,14 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
"default Wireguard port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
}.WithDefaults(providers.Expressvpn),
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1195,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: settings.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
}.WithDefaults(providers.Expressvpn),
|
||||
servers: []models.Server{
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1195,
|
||||
Protocol: constants.UDP,
|
||||
Hostname: "b",
|
||||
},
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(provider),
|
||||
panicMessage: "no default Wireguard port is defined!",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -82,12 +80,23 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().FilterServers(provider, testCase.selection).
|
||||
Return(testCase.filteredServers, testCase.storageErr)
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
provider := New(storage, randSource)
|
||||
|
||||
connection, err := m.GetConnection(testCase.selection)
|
||||
if testCase.panicMessage != "" {
|
||||
assert.PanicsWithValue(t, testCase.panicMessage, func() {
|
||||
_, _ = provider.GetConnection(testCase.selection)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
connection, err := provider.GetConnection(testCase.selection)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package fastestvpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(4443, 4443, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Fastestvpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package hidemyass
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(8080, 553, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.HideMyAss,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package ipvanish
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Ipvanish,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package ivpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 58237) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Ivpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -1,41 +1,67 @@
|
||||
package ivpn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Provider_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const provider = providers.Ivpn
|
||||
|
||||
errTest := errors.New("test error")
|
||||
boolPtr := func(b bool) *bool { return &b }
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
filteredServers []models.Server
|
||||
storageErr error
|
||||
selection settings.ServerSelection
|
||||
connection models.Connection
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no server available": {
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn),
|
||||
errWrapped: utils.ErrNoServer,
|
||||
errMessage: "no server",
|
||||
"error": {
|
||||
storageErr: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
"default OpenVPN TCP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn),
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 443,
|
||||
Protocol: constants.TCP,
|
||||
},
|
||||
},
|
||||
"default OpenVPN UDP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
@@ -43,51 +69,36 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
"default Wireguard port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
}.WithDefaults(providers.Ivpn),
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Type: vpn.Wireguard,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 58237,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: settings.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
}.WithDefaults(providers.Ivpn),
|
||||
servers: []models.Server{
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
Hostname: "b",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().FilterServers(provider, testCase.selection).
|
||||
Return(testCase.filteredServers, testCase.storageErr)
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
provider := New(storage, randSource)
|
||||
|
||||
connection, err := m.GetConnection(testCase.selection)
|
||||
connection, err := provider.GetConnection(testCase.selection)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package mullvad
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Mullvad,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -1,41 +1,67 @@
|
||||
package mullvad
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Provider_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const provider = providers.Mullvad
|
||||
|
||||
errTest := errors.New("test error")
|
||||
boolPtr := func(b bool) *bool { return &b }
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
filteredServers []models.Server
|
||||
storageErr error
|
||||
selection settings.ServerSelection
|
||||
connection models.Connection
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no server available": {
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad),
|
||||
errWrapped: utils.ErrNoServer,
|
||||
errMessage: "no server",
|
||||
"error": {
|
||||
storageErr: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.Server{
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
"default OpenVPN TCP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad),
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 443,
|
||||
Protocol: constants.TCP,
|
||||
},
|
||||
},
|
||||
"default OpenVPN UDP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
@@ -43,36 +69,17 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
"default Wireguard port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
}.WithDefaults(providers.Mullvad),
|
||||
servers: []models.Server{
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: settings.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
}.WithDefaults(providers.Mullvad),
|
||||
servers: []models.Server{
|
||||
{VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Hostname: "b",
|
||||
Type: vpn.Wireguard,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 51820,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
@@ -82,12 +89,16 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().FilterServers(provider, testCase.selection).
|
||||
Return(testCase.filteredServers, testCase.storageErr)
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
provider := New(storage, randSource)
|
||||
|
||||
connection, err := m.GetConnection(testCase.selection)
|
||||
connection, err := provider.GetConnection(testCase.selection)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Mullvad),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package nordvpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Nordvpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Nordvpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package perfectprivacy
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Perfectprivacy,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Perfectprivacy),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package privado
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Privado,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Privado),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package privateinternetaccess
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/privateinternetaccess/presets"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
@@ -20,5 +21,6 @@ func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
defaults.OpenVPNUDPPort = 1197
|
||||
}
|
||||
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.PrivateInternetAccess,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -15,12 +15,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/golibs/format"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrServerNameNotFound = errors.New("server name not found in servers")
|
||||
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
||||
ErrServerNameEmpty = errors.New("server name is empty")
|
||||
)
|
||||
@@ -29,11 +30,9 @@ var (
|
||||
func (p *Provider) PortForward(ctx context.Context, client *http.Client,
|
||||
logger utils.Logger, gateway net.IP, serverName string) (
|
||||
port uint16, err error) {
|
||||
var server models.Server
|
||||
for _, server = range p.servers {
|
||||
if server.ServerName == serverName {
|
||||
break
|
||||
}
|
||||
server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
|
||||
}
|
||||
|
||||
if !server.PortForward {
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
timeNow func() time.Time
|
||||
// Port forwarding
|
||||
@@ -17,11 +17,11 @@ type Provider struct {
|
||||
authFilePath string
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source,
|
||||
func New(storage common.Storage, randSource rand.Source,
|
||||
timeNow func() time.Time) *Provider {
|
||||
const jsonPortForwardPath = "/gluetun/piaportforward.json"
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
timeNow: timeNow,
|
||||
randSource: randSource,
|
||||
portForwardPath: jsonPortForwardPath,
|
||||
|
||||
@@ -2,6 +2,7 @@ package privatevpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Privatevpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package protonvpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Protonvpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn),
|
||||
}
|
||||
|
||||
@@ -50,52 +50,57 @@ type PortForwarder interface {
|
||||
port uint16, gateway net.IP, serverName string) (err error)
|
||||
}
|
||||
|
||||
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {
|
||||
serversSlice := allServers.ServersSlice(provider)
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error)
|
||||
GetServerByName(provider, name string) (server models.Server, ok bool)
|
||||
}
|
||||
|
||||
func New(provider string, storage Storage, timeNow func() time.Time) Provider {
|
||||
randSource := rand.NewSource(timeNow().UnixNano())
|
||||
switch provider {
|
||||
case providers.Custom:
|
||||
return custom.New()
|
||||
case providers.Cyberghost:
|
||||
return cyberghost.New(serversSlice, randSource)
|
||||
return cyberghost.New(storage, randSource)
|
||||
case providers.Expressvpn:
|
||||
return expressvpn.New(serversSlice, randSource)
|
||||
return expressvpn.New(storage, randSource)
|
||||
case providers.Fastestvpn:
|
||||
return fastestvpn.New(serversSlice, randSource)
|
||||
return fastestvpn.New(storage, randSource)
|
||||
case providers.HideMyAss:
|
||||
return hidemyass.New(serversSlice, randSource)
|
||||
return hidemyass.New(storage, randSource)
|
||||
case providers.Ipvanish:
|
||||
return ipvanish.New(serversSlice, randSource)
|
||||
return ipvanish.New(storage, randSource)
|
||||
case providers.Ivpn:
|
||||
return ivpn.New(serversSlice, randSource)
|
||||
return ivpn.New(storage, randSource)
|
||||
case providers.Mullvad:
|
||||
return mullvad.New(serversSlice, randSource)
|
||||
return mullvad.New(storage, randSource)
|
||||
case providers.Nordvpn:
|
||||
return nordvpn.New(serversSlice, randSource)
|
||||
return nordvpn.New(storage, randSource)
|
||||
case providers.Perfectprivacy:
|
||||
return perfectprivacy.New(serversSlice, randSource)
|
||||
return perfectprivacy.New(storage, randSource)
|
||||
case providers.Privado:
|
||||
return privado.New(serversSlice, randSource)
|
||||
return privado.New(storage, randSource)
|
||||
case providers.PrivateInternetAccess:
|
||||
return privateinternetaccess.New(serversSlice, randSource, timeNow)
|
||||
return privateinternetaccess.New(storage, randSource, timeNow)
|
||||
case providers.Privatevpn:
|
||||
return privatevpn.New(serversSlice, randSource)
|
||||
return privatevpn.New(storage, randSource)
|
||||
case providers.Protonvpn:
|
||||
return protonvpn.New(serversSlice, randSource)
|
||||
return protonvpn.New(storage, randSource)
|
||||
case providers.Purevpn:
|
||||
return purevpn.New(serversSlice, randSource)
|
||||
return purevpn.New(storage, randSource)
|
||||
case providers.Surfshark:
|
||||
return surfshark.New(serversSlice, randSource)
|
||||
return surfshark.New(storage, randSource)
|
||||
case providers.Torguard:
|
||||
return torguard.New(serversSlice, randSource)
|
||||
return torguard.New(storage, randSource)
|
||||
case providers.VPNUnlimited:
|
||||
return vpnunlimited.New(serversSlice, randSource)
|
||||
return vpnunlimited.New(storage, randSource)
|
||||
case providers.Vyprvpn:
|
||||
return vyprvpn.New(serversSlice, randSource)
|
||||
return vyprvpn.New(storage, randSource)
|
||||
case providers.Wevpn:
|
||||
return wevpn.New(serversSlice, randSource)
|
||||
return wevpn.New(storage, randSource)
|
||||
case providers.Windscribe:
|
||||
return windscribe.New(serversSlice, randSource)
|
||||
return windscribe.New(storage, randSource)
|
||||
default:
|
||||
panic("provider " + provider + " is unknown") // should never occur
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package purevpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(80, 53, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Purevpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package surfshark
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(1443, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Surfshark,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package torguard
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(1912, 1912, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Torguard,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard),
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
@@ -24,20 +24,20 @@ func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort,
|
||||
}
|
||||
}
|
||||
|
||||
var ErrNoServer = errors.New("no server")
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error)
|
||||
}
|
||||
|
||||
func GetConnection(servers []models.Server,
|
||||
func GetConnection(provider string,
|
||||
storage Storage,
|
||||
selection settings.ServerSelection,
|
||||
defaults ConnectionDefaults,
|
||||
randSource rand.Source) (
|
||||
connection models.Connection, err error) {
|
||||
if len(servers) == 0 {
|
||||
return connection, ErrNoServer
|
||||
}
|
||||
|
||||
servers = filterServers(servers, selection)
|
||||
if len(servers) == 0 {
|
||||
return connection, noServerFoundError(selection)
|
||||
servers, err := storage.FilterServers(provider, selection)
|
||||
if err != nil {
|
||||
return connection, fmt.Errorf("cannot filter servers: %w", err)
|
||||
}
|
||||
|
||||
protocol := getProtocol(selection)
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
provider string
|
||||
filteredServers []models.Server
|
||||
filterError error
|
||||
serverSelection settings.ServerSelection
|
||||
defaults ConnectionDefaults
|
||||
randSource rand.Source
|
||||
@@ -25,25 +32,13 @@ func Test_GetConnection(t *testing.T) {
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no server": {
|
||||
serverSelection: settings.ServerSelection{}.
|
||||
WithDefaults(providers.Mullvad),
|
||||
errWrapped: ErrNoServer,
|
||||
errMessage: "no server",
|
||||
},
|
||||
"all servers filtered": {
|
||||
servers: []models.Server{
|
||||
{VPN: vpn.Wireguard},
|
||||
{VPN: vpn.Wireguard},
|
||||
},
|
||||
serverSelection: settings.ServerSelection{
|
||||
VPN: vpn.OpenVPN,
|
||||
}.WithDefaults(providers.Mullvad),
|
||||
errWrapped: ErrNoServerFound,
|
||||
errMessage: "no server found: for VPN openvpn; protocol udp",
|
||||
"storage filter error": {
|
||||
filterError: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"server without IPs": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{VPN: vpn.OpenVPN, UDP: true},
|
||||
{VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
@@ -58,7 +53,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
errMessage: "no connection to pick from",
|
||||
},
|
||||
"OpenVPN server with hostname": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -79,7 +74,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"OpenVPN server with x509": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -101,7 +96,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"server with IPv4 and IPv6": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -128,7 +123,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"mixed servers": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -169,8 +164,14 @@ func Test_GetConnection(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
connection, err := GetConnection(testCase.servers,
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().
|
||||
FilterServers(testCase.provider, testCase.serverSelection).
|
||||
Return(testCase.filteredServers, testCase.filterError)
|
||||
|
||||
connection, err := GetConnection(testCase.provider, storage,
|
||||
testCase.serverSelection, testCase.defaults,
|
||||
testCase.randSource)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package vpnunlimited
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.VPNUnlimited,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package vyprvpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Vyprvpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package wevpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(1195, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Wevpn,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -1,43 +1,68 @@
|
||||
package wevpn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Provider_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const provider = providers.Wevpn
|
||||
|
||||
errTest := errors.New("test error")
|
||||
boolPtr := func(b bool) *bool { return &b }
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
filteredServers []models.Server
|
||||
storageErr error
|
||||
selection settings.ServerSelection
|
||||
connection models.Connection
|
||||
errWrapped error
|
||||
errMessage string
|
||||
panicMessage string
|
||||
}{
|
||||
"no server available": {
|
||||
"error": {
|
||||
storageErr: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"default OpenVPN TCP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
VPN: vpn.OpenVPN,
|
||||
}.WithDefaults(providers.Wevpn),
|
||||
errWrapped: utils.ErrNoServer,
|
||||
errMessage: "no server",
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 1195,
|
||||
Protocol: constants.TCP,
|
||||
},
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Wevpn),
|
||||
},
|
||||
"default OpenVPN UDP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
@@ -45,38 +70,14 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
"default Wireguard port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
}.WithDefaults(providers.Wevpn),
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: settings.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
}.WithDefaults(providers.Wevpn),
|
||||
servers: []models.Server{
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Hostname: "b",
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(provider),
|
||||
panicMessage: "no default Wireguard port is defined!",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -84,12 +85,23 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().FilterServers(provider, testCase.selection).
|
||||
Return(testCase.filteredServers, testCase.storageErr)
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
provider := New(storage, randSource)
|
||||
|
||||
connection, err := m.GetConnection(testCase.selection)
|
||||
if testCase.panicMessage != "" {
|
||||
assert.PanicsWithValue(t, testCase.panicMessage, func() {
|
||||
_, _ = provider.GetConnection(testCase.selection)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
connection, err := provider.GetConnection(testCase.selection)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package windscribe
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -9,5 +10,6 @@ import (
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 1194) //nolint:gomnd
|
||||
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
|
||||
return utils.GetConnection(providers.Windscribe,
|
||||
p.storage, selection, defaults, p.randSource)
|
||||
}
|
||||
|
||||
@@ -1,41 +1,68 @@
|
||||
package windscribe
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Provider_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const provider = providers.Windscribe
|
||||
|
||||
errTest := errors.New("test error")
|
||||
boolPtr := func(b bool) *bool { return &b }
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
filteredServers []models.Server
|
||||
storageErr error
|
||||
selection settings.ServerSelection
|
||||
connection models.Connection
|
||||
errWrapped error
|
||||
errMessage string
|
||||
panicMessage string
|
||||
}{
|
||||
"no server available": {
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe),
|
||||
errWrapped: utils.ErrNoServer,
|
||||
errMessage: "no server",
|
||||
"error": {
|
||||
storageErr: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.Server{
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
"default OpenVPN TCP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe),
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 443,
|
||||
Protocol: constants.TCP,
|
||||
},
|
||||
},
|
||||
"default OpenVPN UDP port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
@@ -43,49 +70,41 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
"default Wireguard port": {
|
||||
filteredServers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
},
|
||||
selection: settings.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
}.WithDefaults(providers.Windscribe),
|
||||
servers: []models.Server{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(provider),
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Type: vpn.Wireguard,
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: settings.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
}.WithDefaults(providers.Windscribe),
|
||||
servers: []models.Server{
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
connection: models.Connection{
|
||||
Type: vpn.OpenVPN,
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Hostname: "b",
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().FilterServers(provider, testCase.selection).
|
||||
Return(testCase.filteredServers, testCase.storageErr)
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
provider := New(testCase.servers, randSource)
|
||||
provider := New(storage, randSource)
|
||||
|
||||
if testCase.panicMessage != "" {
|
||||
assert.PanicsWithValue(t, testCase.panicMessage, func() {
|
||||
_, _ = provider.GetConnection(testCase.selection)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
connection, err := provider.GetConnection(testCase.selection)
|
||||
|
||||
@@ -93,6 +112,7 @@ func Test_Provider_GetConnection(t *testing.T) {
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.connection, connection)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
servers []models.Server
|
||||
storage common.Storage
|
||||
randSource rand.Source
|
||||
utils.NoPortForwarder
|
||||
}
|
||||
|
||||
func New(servers []models.Server, randSource rand.Source) *Provider {
|
||||
func New(storage common.Storage, randSource rand.Source) *Provider {
|
||||
return &Provider{
|
||||
servers: servers,
|
||||
storage: storage,
|
||||
randSource: randSource,
|
||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Windscribe),
|
||||
}
|
||||
|
||||
27
internal/storage/choices.go
Normal file
27
internal/storage/choices.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings/validation"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func (s *Storage) GetFilterChoices(provider string) models.FilterChoices {
|
||||
if provider == providers.Custom {
|
||||
return models.FilterChoices{}
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
servers := serversObject.Servers
|
||||
return models.FilterChoices{
|
||||
Countries: validation.ExtractCountries(servers),
|
||||
Regions: validation.ExtractRegions(servers),
|
||||
Cities: validation.ExtractCities(servers),
|
||||
ISPs: validation.ExtractISPs(servers),
|
||||
Names: validation.ExtractServerNames(servers),
|
||||
Hostnames: validation.ExtractHostnames(servers),
|
||||
}
|
||||
}
|
||||
32
internal/storage/copy.go
Normal file
32
internal/storage/copy.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func copyServer(server models.Server) (serverCopy models.Server) {
|
||||
serverCopy = server
|
||||
serverCopy.IPs = copyIPs(server.IPs)
|
||||
return serverCopy
|
||||
}
|
||||
|
||||
func copyIPs(toCopy []net.IP) (copied []net.IP) {
|
||||
if toCopy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copied = make([]net.IP, len(toCopy))
|
||||
for i := range toCopy {
|
||||
copied[i] = copyIP(toCopy[i])
|
||||
}
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func copyIP(toCopy net.IP) (copied net.IP) {
|
||||
copied = make(net.IP, len(toCopy))
|
||||
copy(copied, toCopy)
|
||||
return copied
|
||||
}
|
||||
75
internal/storage/copy_test.go
Normal file
75
internal/storage/copy_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_copyServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := models.Server{
|
||||
Country: "a",
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}
|
||||
|
||||
serverCopy := copyServer(server)
|
||||
|
||||
assert.Equal(t, server, serverCopy)
|
||||
// Check for mutation
|
||||
serverCopy.IPs[0][0] = 9
|
||||
assert.NotEqual(t, server, serverCopy)
|
||||
}
|
||||
|
||||
func Test_copyIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
toCopy []net.IP
|
||||
copied []net.IP
|
||||
}{
|
||||
"nil": {},
|
||||
"empty": {
|
||||
toCopy: []net.IP{},
|
||||
copied: []net.IP{},
|
||||
},
|
||||
"single IP": {
|
||||
toCopy: []net.IP{{1, 1, 1, 1}},
|
||||
copied: []net.IP{{1, 1, 1, 1}},
|
||||
},
|
||||
"two IPs": {
|
||||
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
|
||||
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Reserver leading 9 for copy modifications below
|
||||
for _, ipToCopy := range testCase.toCopy {
|
||||
require.NotEqual(t, 9, ipToCopy[0])
|
||||
}
|
||||
|
||||
copied := copyIPs(testCase.toCopy)
|
||||
|
||||
assert.Equal(t, testCase.copied, copied)
|
||||
|
||||
if len(copied) > 0 {
|
||||
original := testCase.toCopy[0][0]
|
||||
testCase.toCopy[0][0] = 9
|
||||
assert.NotEqual(t, 9, copied[0][0])
|
||||
testCase.toCopy[0][0] = original
|
||||
|
||||
copied[0][0] = 9
|
||||
assert.NotEqual(t, 9, testCase.toCopy[0][0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
143
internal/storage/filter.go
Normal file
143
internal/storage/filter.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
// FilterServers filter servers for the given provider and according
|
||||
// to the given selection. The filtered servers are deep copied so they
|
||||
// are safe for mutation by the caller.
|
||||
func (s *Storage) FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error) {
|
||||
if provider == providers.Custom {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
allServers := serversObject.Servers
|
||||
|
||||
if len(allServers) == 0 {
|
||||
return nil, ErrNoServerFound
|
||||
}
|
||||
|
||||
for _, server := range allServers {
|
||||
if filterServer(server, selection) {
|
||||
continue
|
||||
}
|
||||
|
||||
server = copyServer(server)
|
||||
servers = append(servers, server)
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
return nil, noServerFoundError(selection)
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func filterServer(server models.Server,
|
||||
selection settings.ServerSelection) (filtered bool) {
|
||||
// Note each condition is split to make sure
|
||||
// we have full testing coverage.
|
||||
if server.VPN != selection.VPN {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByProtocol(selection, server.TCP, server.UDP) {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.MultiHopOnly && !server.MultiHop {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.FreeOnly && !server.Free {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.StreamOnly && !server.Stream {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.OwnedOnly && !server.Owned {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.Country, selection.Countries) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.Region, selection.Regions) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.City, selection.Cities) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.ISP, selection.ISPs) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilitiesUint16(server.Number, selection.Numbers) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.ServerName, selection.Names) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.Hostname, selection.Hostnames) {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO filter port forward server for PIA
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func filterByPossibilities(value string, possibilities []string) (filtered bool) {
|
||||
if len(possibilities) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, possibility := range possibilities {
|
||||
if strings.EqualFold(value, possibility) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO merge with filterByPossibilities with generics in Go 1.18.
|
||||
func filterByPossibilitiesUint16(value uint16, possibilities []uint16) (filtered bool) {
|
||||
if len(possibilities) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, possibility := range possibilities {
|
||||
if value == possibility {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func filterByProtocol(selection settings.ServerSelection,
|
||||
serverTCP, serverUDP bool) (filtered bool) {
|
||||
switch selection.VPN {
|
||||
case vpn.Wireguard:
|
||||
return !serverUDP
|
||||
default: // OpenVPN
|
||||
wantTCP := *selection.OpenVPN.TCP
|
||||
wantUDP := !wantTCP
|
||||
return (wantTCP && !serverTCP) || (wantUDP && !serverUDP)
|
||||
}
|
||||
}
|
||||
@@ -4,21 +4,20 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var _ Flusher = (*Storage)(nil)
|
||||
// FlushToFile flushes the merged servers data to the file
|
||||
// specified by path, as indented JSON.
|
||||
func (s *Storage) FlushToFile(path string) error {
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
type Flusher interface {
|
||||
FlushToFile(allServers *models.AllServers) error
|
||||
return s.flushToFile(path)
|
||||
}
|
||||
|
||||
func (s *Storage) FlushToFile(allServers *models.AllServers) error {
|
||||
return flushToFile(s.filepath, allServers)
|
||||
}
|
||||
|
||||
func flushToFile(path string, servers *models.AllServers) error {
|
||||
// flushToFile flushes the merged servers data to the file
|
||||
// specified by path, as indented JSON. It is not thread-safe.
|
||||
func (s *Storage) flushToFile(path string) error {
|
||||
dirPath := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dirPath, 0644); err != nil {
|
||||
return err
|
||||
@@ -28,11 +27,15 @@ func flushToFile(path string, servers *models.AllServers) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(servers); err != nil {
|
||||
|
||||
err = encoder.Encode(&s.mergedServers)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package utils
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -1,7 +1,118 @@
|
||||
package storage
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/models"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
func (s *Storage) GetServers() models.AllServers {
|
||||
return s.mergedServers.GetCopy()
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
// SetServers sets the given servers for the given provider
|
||||
// in the storage in-memory map and saves all the servers
|
||||
// to file.
|
||||
// Note the servers given are not copied so the caller must
|
||||
// NOT MUTATE them after calling this method.
|
||||
func (s *Storage) SetServers(provider string, servers []models.Server) (err error) {
|
||||
if provider == providers.Custom {
|
||||
return
|
||||
}
|
||||
|
||||
s.mergedMutex.Lock()
|
||||
defer s.mergedMutex.Unlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
serversObject.Timestamp = time.Now().Unix()
|
||||
serversObject.Servers = servers
|
||||
s.mergedServers.ProviderToServers[provider] = serversObject
|
||||
|
||||
err = s.flushToFile(s.filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot save servers to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerByName returns the server for the given provider
|
||||
// and server name. It returns `ok` as false if the server is
|
||||
// not found. The returned server is also deep copied so it is
|
||||
// safe for mutation and/or thread safe use.
|
||||
func (s *Storage) GetServerByName(provider, name string) (
|
||||
server models.Server, ok bool) {
|
||||
if provider == providers.Custom {
|
||||
return server, false
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
for _, server := range serversObject.Servers {
|
||||
if server.ServerName == name {
|
||||
return copyServer(server), true
|
||||
}
|
||||
}
|
||||
|
||||
return server, false
|
||||
}
|
||||
|
||||
// GetServersCount returns the number of servers for the provider given.
|
||||
func (s *Storage) GetServersCount(provider string) (count int) {
|
||||
if provider == providers.Custom {
|
||||
return 0
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
return len(serversObject.Servers)
|
||||
}
|
||||
|
||||
// FormatToMarkdown Markdown formats the servers for the provider given
|
||||
// and returns the resulting string.
|
||||
func (s *Storage) FormatToMarkdown(provider string) (formatted string) {
|
||||
if provider == providers.Custom {
|
||||
return ""
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
formatted = serversObject.ToMarkdown(provider)
|
||||
return formatted
|
||||
}
|
||||
|
||||
// GetServersCount returns the number of servers for the provider given.
|
||||
func (s *Storage) ServersAreEqual(provider string, servers []models.Server) (equal bool) {
|
||||
if provider == providers.Custom {
|
||||
return true
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
existingServers := serversObject.Servers
|
||||
|
||||
if len(existingServers) != len(servers) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range existingServers {
|
||||
if !existingServers[i].Equal(servers[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Storage) getMergedServersObject(provider string) (serversObject models.Servers) {
|
||||
serversObject, ok := s.mergedServers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in in-memory servers map", provider))
|
||||
}
|
||||
return serversObject
|
||||
}
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
mergedServers models.AllServers
|
||||
mergedMutex sync.RWMutex
|
||||
// this is stored in memory to avoid re-parsing
|
||||
// the embedded JSON file on every call to the
|
||||
// SyncServers method.
|
||||
|
||||
@@ -29,6 +29,9 @@ func (s *Storage) syncServers() (err error) {
|
||||
hardcodedCount := countServers(s.hardcodedServers)
|
||||
countOnFile := countServers(serversOnFile)
|
||||
|
||||
s.mergedMutex.Lock()
|
||||
defer s.mergedMutex.Unlock()
|
||||
|
||||
if countOnFile == 0 {
|
||||
s.logger.Info(fmt.Sprintf(
|
||||
"creating %s with %d hardcoded servers",
|
||||
@@ -47,7 +50,8 @@ func (s *Storage) syncServers() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
|
||||
err = s.flushToFile(s.filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write servers to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
)
|
||||
|
||||
@@ -24,15 +23,13 @@ type Looper interface {
|
||||
}
|
||||
|
||||
type Updater interface {
|
||||
UpdateServers(ctx context.Context) (allServers models.AllServers, err error)
|
||||
UpdateServers(ctx context.Context) (err error)
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
state state
|
||||
// Objects
|
||||
updater Updater
|
||||
flusher storage.Flusher
|
||||
setAllServers func(allServers models.AllServers)
|
||||
logger Logger
|
||||
// Internal channels and locks
|
||||
loopLock sync.Mutex
|
||||
@@ -49,23 +46,26 @@ type looper struct {
|
||||
|
||||
const defaultBackoffTime = 5 * time.Second
|
||||
|
||||
type Storage interface {
|
||||
SetServers(provider string, servers []models.Server) (err error)
|
||||
GetServersCount(provider string) (count int)
|
||||
ServersAreEqual(provider string, servers []models.Server) (equal bool)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func NewLooper(settings settings.Updater, currentServers models.AllServers,
|
||||
flusher storage.Flusher, setAllServers func(allServers models.AllServers),
|
||||
func NewLooper(settings settings.Updater, storage Storage,
|
||||
client *http.Client, logger Logger) Looper {
|
||||
return &looper{
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
settings: settings,
|
||||
},
|
||||
updater: updater.New(settings, client, currentServers, logger),
|
||||
flusher: flusher,
|
||||
setAllServers: setAllServers,
|
||||
updater: updater.New(settings, client, storage, logger),
|
||||
logger: logger,
|
||||
start: make(chan struct{}),
|
||||
running: make(chan models.LoopStatus),
|
||||
@@ -106,20 +106,19 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||
for ctx.Err() == nil {
|
||||
updateCtx, updateCancel := context.WithCancel(ctx)
|
||||
|
||||
serversCh := make(chan models.AllServers)
|
||||
errorCh := make(chan error)
|
||||
runWg := &sync.WaitGroup{}
|
||||
runWg.Add(1)
|
||||
go func() {
|
||||
defer runWg.Done()
|
||||
servers, err := l.updater.UpdateServers(updateCtx)
|
||||
err := l.updater.UpdateServers(updateCtx)
|
||||
if err != nil {
|
||||
if updateCtx.Err() == nil {
|
||||
errorCh <- err
|
||||
}
|
||||
return
|
||||
}
|
||||
serversCh <- servers
|
||||
l.state.setStatusWithLock(constants.Completed)
|
||||
}()
|
||||
|
||||
if !crashed {
|
||||
@@ -148,16 +147,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||
updateCancel()
|
||||
runWg.Wait()
|
||||
l.stopped <- struct{}{}
|
||||
case servers := <-serversCh:
|
||||
l.setAllServers(servers)
|
||||
if err := l.flusher.FlushToFile(&servers); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
runWg.Wait()
|
||||
l.state.setStatusWithLock(constants.Completed)
|
||||
l.logger.Info("Updated servers information")
|
||||
case err := <-errorCh:
|
||||
close(serversCh)
|
||||
runWg.Wait()
|
||||
l.state.setStatusWithLock(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
|
||||
@@ -3,8 +3,6 @@ package updater
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -31,18 +29,25 @@ import (
|
||||
)
|
||||
|
||||
func (u *Updater) updateProvider(ctx context.Context, provider string) (err error) {
|
||||
existingServers := u.getProviderServers(provider)
|
||||
minServers := getMinServers(existingServers)
|
||||
existingServersCount := u.storage.GetServersCount(provider)
|
||||
minServers := getMinServers(existingServersCount)
|
||||
servers, err := u.getServers(ctx, provider, minServers)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("cannot get servers: %w", err)
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(existingServers, servers) {
|
||||
if u.storage.ServersAreEqual(provider, servers) {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.patchProvider(provider, servers)
|
||||
// Note the servers variable must NOT BE MUTATED after this call,
|
||||
// since the implementation does not deep copy the servers.
|
||||
// TODO set in storage in provider updater directly, server by server,
|
||||
// to avoid accumulating server data in memory.
|
||||
err = u.storage.SetServers(provider, servers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot set servers to storage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -101,25 +106,7 @@ func (u *Updater) getServers(ctx context.Context, provider string,
|
||||
return providerUpdater.GetServers(ctx, minServers)
|
||||
}
|
||||
|
||||
func (u *Updater) getProviderServers(provider string) (servers []models.Server) {
|
||||
providerServers, ok := u.servers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s is unknown", provider))
|
||||
}
|
||||
return providerServers.Servers
|
||||
}
|
||||
|
||||
func getMinServers(servers []models.Server) (minServers int) {
|
||||
func getMinServers(existingServersCount int) (minServers int) {
|
||||
const minRatio = 0.8
|
||||
return int(minRatio * float64(len(servers)))
|
||||
}
|
||||
|
||||
func (u *Updater) patchProvider(provider string, servers []models.Server) {
|
||||
providerServers, ok := u.servers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s is unknown", provider))
|
||||
}
|
||||
providerServers.Timestamp = time.Now().Unix()
|
||||
providerServers.Servers = servers
|
||||
u.servers.ProviderToServers[provider] = providerServers
|
||||
return int(minRatio * float64(existingServersCount))
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Updater struct {
|
||||
options settings.Updater
|
||||
|
||||
// state
|
||||
servers models.AllServers
|
||||
storage Storage
|
||||
|
||||
// Functions for tests
|
||||
logger Logger
|
||||
@@ -29,6 +29,12 @@ type Updater struct {
|
||||
unzipper unzip.Unzipper
|
||||
}
|
||||
|
||||
type Storage interface {
|
||||
SetServers(provider string, servers []models.Server) (err error)
|
||||
GetServersCount(provider string) (count int)
|
||||
ServersAreEqual(provider string, servers []models.Server) (equal bool)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
@@ -36,20 +42,20 @@ type Logger interface {
|
||||
}
|
||||
|
||||
func New(settings settings.Updater, httpClient *http.Client,
|
||||
currentServers models.AllServers, logger Logger) *Updater {
|
||||
storage Storage, logger Logger) *Updater {
|
||||
unzipper := unzip.New(httpClient)
|
||||
return &Updater{
|
||||
options: settings,
|
||||
storage: storage,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
presolver: resolver.NewParallelResolver(settings.DNSAddress.String()),
|
||||
client: httpClient,
|
||||
unzipper: unzipper,
|
||||
options: settings,
|
||||
servers: currentServers,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) {
|
||||
func (u *Updater) UpdateServers(ctx context.Context) (err error) {
|
||||
caser := cases.Title(language.English)
|
||||
for _, provider := range u.options.Providers {
|
||||
u.logger.Info("updating " + caser.String(provider) + " servers...")
|
||||
@@ -62,17 +68,17 @@ func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServe
|
||||
|
||||
// return the only error for the single provider.
|
||||
if len(u.options.Providers) == 1 {
|
||||
return allServers, err
|
||||
return err
|
||||
}
|
||||
|
||||
// stop updating the next providers if context is canceled.
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return allServers, ctxErr
|
||||
return ctxErr
|
||||
}
|
||||
|
||||
// Log the error and continue updating the next provider.
|
||||
u.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
return u.servers, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,12 +27,12 @@ type Looper interface {
|
||||
loopstate.Getter
|
||||
loopstate.Applier
|
||||
SettingsGetSetter
|
||||
ServersGetterSetter
|
||||
}
|
||||
|
||||
type Loop struct {
|
||||
statusManager loopstate.Manager
|
||||
state state.Manager
|
||||
storage Storage
|
||||
// Fixed parameters
|
||||
buildInfo models.BuildInformation
|
||||
versionInfo bool
|
||||
@@ -64,12 +64,17 @@ type firewallConfigurer interface {
|
||||
firewall.PortAllower
|
||||
}
|
||||
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error)
|
||||
GetServerByName(provider, name string) (server models.Server, ok bool)
|
||||
}
|
||||
|
||||
const (
|
||||
defaultBackoffTime = 15 * time.Second
|
||||
)
|
||||
|
||||
func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16,
|
||||
allServers models.AllServers, openvpnConf openvpn.Interface,
|
||||
storage Storage, openvpnConf openvpn.Interface,
|
||||
netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter,
|
||||
portForward portforward.StartStopper, starter command.Starter,
|
||||
publicip publicip.Looper, dnsLooper dns.Looper,
|
||||
@@ -81,11 +86,12 @@ func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16,
|
||||
stopped := make(chan struct{})
|
||||
|
||||
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
|
||||
state := state.New(statusManager, vpnSettings, allServers)
|
||||
state := state.New(statusManager, vpnSettings)
|
||||
|
||||
return &Loop{
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
storage: storage,
|
||||
buildInfo: buildInfo,
|
||||
versionInfo: versionInfo,
|
||||
vpnInputPorts: vpnInputPorts,
|
||||
|
||||
@@ -28,9 +28,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
settings, allServers := l.state.GetSettingsAndServers()
|
||||
settings := l.state.GetSettings()
|
||||
|
||||
providerConf := provider.New(*settings.Provider.Name, allServers, time.Now)
|
||||
providerConf := provider.New(*settings.Provider.Name, l.storage, time.Now)
|
||||
|
||||
portForwarding := *settings.Provider.PortForwarding.Enabled
|
||||
var vpnRunner vpnRunner
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/vpn/state"
|
||||
)
|
||||
|
||||
type ServersGetterSetter = state.ServersGetterSetter
|
||||
|
||||
func (l *Loop) GetServers() (servers models.AllServers) {
|
||||
return l.state.GetServers()
|
||||
}
|
||||
|
||||
func (l *Loop) SetServers(servers models.AllServers) {
|
||||
l.state.SetServers(servers)
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package state
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/models"
|
||||
|
||||
type ServersGetterSetter interface {
|
||||
GetServers() (servers models.AllServers)
|
||||
SetServers(servers models.AllServers)
|
||||
}
|
||||
|
||||
func (s *State) GetServers() (servers models.AllServers) {
|
||||
s.allServersMu.RLock()
|
||||
defer s.allServersMu.RUnlock()
|
||||
return s.allServers
|
||||
}
|
||||
|
||||
func (s *State) SetServers(servers models.AllServers) {
|
||||
s.allServersMu.Lock()
|
||||
defer s.allServersMu.Unlock()
|
||||
s.allServers = servers
|
||||
}
|
||||
@@ -5,23 +5,18 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var _ Manager = (*State)(nil)
|
||||
|
||||
type Manager interface {
|
||||
SettingsGetSetter
|
||||
ServersGetterSetter
|
||||
GetSettingsAndServers() (vpn settings.VPN, allServers models.AllServers)
|
||||
}
|
||||
|
||||
func New(statusApplier loopstate.Applier,
|
||||
vpn settings.VPN, allServers models.AllServers) *State {
|
||||
func New(statusApplier loopstate.Applier, vpn settings.VPN) *State {
|
||||
return &State{
|
||||
statusApplier: statusApplier,
|
||||
vpn: vpn,
|
||||
allServers: allServers,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,18 +25,4 @@ type State struct {
|
||||
|
||||
vpn settings.VPN
|
||||
settingsMu sync.RWMutex
|
||||
|
||||
allServers models.AllServers
|
||||
allServersMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *State) GetSettingsAndServers() (vpn settings.VPN,
|
||||
allServers models.AllServers) {
|
||||
s.settingsMu.RLock()
|
||||
s.allServersMu.RLock()
|
||||
vpn = s.vpn
|
||||
allServers = s.allServers
|
||||
s.settingsMu.RUnlock()
|
||||
s.allServersMu.RUnlock()
|
||||
return vpn, allServers
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user