chore(all): memory and thread safe storage

- settings: get filter choices from storage for settings validation
- updater: update servers to the storage
- storage: minimal deep copying and data duplication
- storage: add merged servers mutex for thread safety
- connection: filter servers in storage
- formatter: format servers to Markdown in storage
- PIA: get server by name from storage directly
- Updater: get servers count from storage directly
- Updater: equality check done in storage, fix #882
This commit is contained in:
Quentin McGaw
2022-06-05 14:58:46 +00:00
parent 1e6b4ed5eb
commit 36b504609b
84 changed files with 1267 additions and 877 deletions

View File

@@ -215,9 +215,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err
}
allServers := storage.GetServers()
err = allSettings.Validate(allServers)
err = allSettings.Validate(storage)
if err != nil {
return err
}
@@ -378,7 +376,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
vpnLogger := logger.New(log.SetComponent("vpn"))
vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts,
allServers, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
storage, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
buildInfo, *allSettings.Version.Enabled)
vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler(
@@ -386,8 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
go vpnLooper.Run(vpnCtx, vpnDone)
updaterLooper := updater.NewLooper(allSettings.Updater,
allServers, storage, vpnLooper.SetServers, httpClient,
logger.New(log.SetComponent("updater")))
storage, httpClient, logger.New(log.SetComponent("updater")))
updaterHandler, updaterCtx, updaterDone := goshutdown.NewGoRoutineHandler(
"updater", goroutine.OptionTimeout(defaultShutdownTimeout))
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker

View File

@@ -10,7 +10,6 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
"golang.org/x/text/cases"
"golang.org/x/text/language"
@@ -80,9 +79,8 @@ func (c *CLI) FormatServers(args []string) error {
if err != nil {
return fmt.Errorf("cannot create servers storage: %w", err)
}
currentServers := storage.GetServers()
formatted := formatServers(currentServers, providerToFormat)
formatted := storage.FormatToMarkdown(providerToFormat)
output = filepath.Clean(output)
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
@@ -103,11 +101,3 @@ func (c *CLI) FormatServers(args []string) error {
return nil
}
func formatServers(allServers models.AllServers, provider string) (formatted string) {
servers, ok := allServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("unknown provider in format map: %s", provider))
}
return servers.ToMarkdown(provider)
}

View File

@@ -25,18 +25,17 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) e
if err != nil {
return err
}
allServers := storage.GetServers()
allSettings, err := source.Read()
if err != nil {
return err
}
if err = allSettings.Validate(allServers); err != nil {
if err = allSettings.Validate(storage); err != nil {
return err
}
providerConf := provider.New(*allSettings.VPN.Provider.Name, allServers, time.Now)
providerConf := provider.New(*allSettings.VPN.Provider.Name, storage, time.Now)
connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection)
if err != nil {
return err

View File

@@ -2,20 +2,17 @@ package cli
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater"
)
@@ -83,41 +80,19 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
if err != nil {
return fmt.Errorf("cannot create servers storage: %w", err)
}
currentServers := storage.GetServers()
updater := updater.New(options, httpClient, currentServers, logger)
allServers, err := updater.UpdateServers(ctx)
updater := updater.New(options, httpClient, storage, logger)
err = updater.UpdateServers(ctx)
if err != nil {
return fmt.Errorf("cannot update server information: %w", err)
}
if endUserMode {
if err := storage.FlushToFile(&allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err)
}
}
if maintainerMode {
if err := writeToEmbeddedJSON(c.repoServersPath, &allServers); err != nil {
return fmt.Errorf("cannot write updated information to file: %w", err)
err := storage.FlushToFile(c.repoServersPath)
if err != nil {
return fmt.Errorf("cannot write servers data to embedded JSON file: %w", err)
}
}
return nil
}
func writeToEmbeddedJSON(repoServersPath string,
allServers *models.AllServers) error {
const perms = 0600
f, err := os.OpenFile(repoServersPath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
if err != nil {
return err
}
defer f.Close()
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
return encoder.Encode(allServers)
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gotree"
)
@@ -23,7 +22,7 @@ type Provider struct {
}
// TODO v4 remove pointer for receiver (because of Surfshark).
func (p *Provider) validate(vpnType string, allServers models.AllServers) (err error) {
func (p *Provider) validate(vpnType string, storage Storage) (err error) {
// Validate Name
var validNames []string
if vpnType == vpn.OpenVPN {
@@ -42,7 +41,7 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
ErrVPNProviderNameNotValid, *p.Name, helpers.ChoicesOrString(validNames))
}
err = p.ServerSelection.validate(*p.Name, allServers)
err = p.ServerSelection.validate(*p.Name, storage)
if err != nil {
return fmt.Errorf("server selection: %w", err)
}

View File

@@ -68,21 +68,19 @@ var (
)
func (ss *ServerSelection) validate(vpnServiceProvider string,
allServers models.AllServers) (err error) {
storage Storage) (err error) {
switch ss.VPN {
case vpn.OpenVPN, vpn.Wireguard:
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
}
countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, allServers)
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage)
if err != nil {
return err // already wrapped error
}
err = validateServerFilters(*ss, countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices)
err = validateServerFilters(*ss, filterChoices)
if err != nil {
if errors.Is(err, helpers.ErrNoChoice) {
return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err)
@@ -135,63 +133,48 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
return nil
}
func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection,
allServers models.AllServers) (
countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices []string,
func getLocationFilterChoices(vpnServiceProvider string,
ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices,
err error) {
providerServers, ok := allServers.ProviderToServers[vpnServiceProvider]
if !ok && vpnServiceProvider != providers.Custom {
panic(fmt.Sprintf("VPN service provider unknown: %s", vpnServiceProvider))
}
servers := providerServers.Servers
countryChoices = validation.ExtractCountries(servers)
regionChoices = validation.ExtractRegions(servers)
cityChoices = validation.ExtractCities(servers)
ispChoices = validation.ExtractISPs(servers)
nameChoices = validation.ExtractServerNames(servers)
hostnameChoices = validation.ExtractHostnames(servers)
filterChoices = storage.GetFilterChoices(vpnServiceProvider)
if vpnServiceProvider == providers.Surfshark {
// // Retro compatibility
// TODO v4 remove
regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...)
if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
filterChoices.Regions = append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...)
if err := helpers.AreAllOneOf(ss.Regions, filterChoices.Regions); err != nil {
return models.FilterChoices{}, fmt.Errorf("%w: %s", ErrRegionNotValid, err)
}
*ss = surfsharkRetroRegion(*ss)
}
return countryChoices, regionChoices, cityChoices,
ispChoices, nameChoices, hostnameChoices, nil
return filterChoices, nil
}
// validateServerFilters validates filters against the choices given as arguments.
// Set an argument to nil to pass the check for a particular filter.
func validateServerFilters(settings ServerSelection,
countryChoices, regionChoices, cityChoices, ispChoices,
nameChoices, hostnameChoices []string) (err error) {
if err := helpers.AreAllOneOf(settings.Countries, countryChoices); err != nil {
func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices) (err error) {
if err := helpers.AreAllOneOf(settings.Countries, filterChoices.Countries); err != nil {
return fmt.Errorf("%w: %s", ErrCountryNotValid, err)
}
if err := helpers.AreAllOneOf(settings.Regions, regionChoices); err != nil {
if err := helpers.AreAllOneOf(settings.Regions, filterChoices.Regions); err != nil {
return fmt.Errorf("%w: %s", ErrRegionNotValid, err)
}
if err := helpers.AreAllOneOf(settings.Cities, cityChoices); err != nil {
if err := helpers.AreAllOneOf(settings.Cities, filterChoices.Cities); err != nil {
return fmt.Errorf("%w: %s", ErrCityNotValid, err)
}
if err := helpers.AreAllOneOf(settings.ISPs, ispChoices); err != nil {
if err := helpers.AreAllOneOf(settings.ISPs, filterChoices.ISPs); err != nil {
return fmt.Errorf("%w: %s", ErrISPNotValid, err)
}
if err := helpers.AreAllOneOf(settings.Hostnames, hostnameChoices); err != nil {
if err := helpers.AreAllOneOf(settings.Hostnames, filterChoices.Hostnames); err != nil {
return fmt.Errorf("%w: %s", ErrHostnameNotValid, err)
}
if err := helpers.AreAllOneOf(settings.Names, nameChoices); err != nil {
if err := helpers.AreAllOneOf(settings.Names, filterChoices.Names); err != nil {
return fmt.Errorf("%w: %s", ErrNameNotValid, err)
}

View File

@@ -24,10 +24,14 @@ type Settings struct {
Pprof pprof.Settings
}
type Storage interface {
GetFilterChoices(provider string) models.FilterChoices
}
// Validate validates all the settings and returns an error
// if one of them is not valid.
// TODO v4 remove pointer for receiver (because of Surfshark).
func (s *Settings) Validate(allServers models.AllServers) (err error) {
func (s *Settings) Validate(storage Storage) (err error) {
nameToValidation := map[string]func() error{
"control server": s.ControlServer.validate,
"dns": s.DNS.validate,
@@ -42,7 +46,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) {
"version": s.Version.validate,
// Pprof validation done in pprof constructor
"VPN": func() error {
return s.VPN.validate(allServers)
return s.VPN.validate(storage)
},
}
@@ -91,7 +95,7 @@ func (s *Settings) MergeWith(other Settings) {
}
func (s *Settings) OverrideWith(other Settings,
allServers models.AllServers) (err error) {
storage Storage) (err error) {
patchedSettings := s.copy()
patchedSettings.ControlServer.overrideWith(other.ControlServer)
patchedSettings.DNS.overrideWith(other.DNS)
@@ -106,7 +110,7 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.overrideWith(other.VPN)
patchedSettings.Pprof.MergeWith(other.Pprof)
err = patchedSettings.Validate(allServers)
err = patchedSettings.Validate(storage)
if err != nil {
return err
}

View File

@@ -35,17 +35,18 @@ func (u Updater) Validate() (err error) {
ErrUpdaterPeriodTooSmall, *u.Period, minPeriod)
}
for i, provider := range u.Providers {
validProviders := providers.All()
for _, provider := range u.Providers {
valid := false
for _, validProvider := range providers.All() {
for _, validProvider := range validProviders {
if provider == validProvider {
valid = true
break
}
}
if !valid {
return fmt.Errorf("%w: %s at index %d",
ErrVPNProviderNameNotValid, provider, i)
return fmt.Errorf("%w: %q can only be one of %s",
ErrVPNProviderNameNotValid, provider, helpers.ChoicesOrString(validProviders))
}
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gotree"
)
@@ -21,7 +20,7 @@ type VPN struct {
}
// TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) validate(allServers models.AllServers) (err error) {
func (v *VPN) validate(storage Storage) (err error) {
// Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
if !helpers.IsOneOf(v.Type, validVPNTypes...) {
@@ -29,7 +28,7 @@ func (v *VPN) validate(allServers models.AllServers) (err error) {
ErrVPNTypeNotValid, v.Type, strings.Join(validVPNTypes, ", "))
}
err = v.Provider.validate(v.Type, allServers)
err = v.Provider.validate(v.Type, storage)
if err != nil {
return fmt.Errorf("provider settings: %w", err)
}

View File

@@ -0,0 +1,10 @@
package models
type FilterChoices struct {
Countries []string
Regions []string
Cities []string
ISPs []string
Names []string
Hostnames []string
}

View File

@@ -1,51 +0,0 @@
package models
import (
"net"
)
func (a AllServers) GetCopy() (allServersCopy AllServers) {
allServersCopy.Version = a.Version
allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
for provider, servers := range a.ProviderToServers {
allServersCopy.ProviderToServers[provider] = Servers{
Version: servers.Version,
Timestamp: servers.Timestamp,
Servers: copyServers(servers.Servers),
}
}
return allServersCopy
}
func copyServers(servers []Server) (serversCopy []Server) {
if servers == nil {
return nil
}
serversCopy = make([]Server, len(servers))
for i, server := range servers {
serversCopy[i] = server
serversCopy[i].IPs = copyIPs(server.IPs)
}
return serversCopy
}
func copyIPs(toCopy []net.IP) (copied []net.IP) {
if toCopy == nil {
return nil
}
copied = make([]net.IP, len(toCopy))
for i := range toCopy {
copied[i] = copyIP(toCopy[i])
}
return copied
}
func copyIP(toCopy net.IP) (copied net.IP) {
copied = make(net.IP, len(toCopy))
copy(copied, toCopy)
return copied
}

View File

@@ -1,173 +0,0 @@
package models
import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_GetCopy(t *testing.T) {
allServers := AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 2,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Expressvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Fastestvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.HideMyAss: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Ipvanish: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Ivpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Mullvad: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Nordvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Perfectprivacy: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Privado: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.PrivateInternetAccess: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Privatevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Protonvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Purevpn: {
Version: 1,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Surfshark: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Torguard: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.VPNUnlimited: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Vyprvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Wevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Windscribe: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
},
}
servers := allServers.GetCopy()
assert.Equal(t, allServers, servers)
}
func Test_copyIPs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
toCopy []net.IP
copied []net.IP
}{
"nil": {},
"empty": {
toCopy: []net.IP{},
copied: []net.IP{},
},
"single IP": {
toCopy: []net.IP{{1, 1, 1, 1}},
copied: []net.IP{{1, 1, 1, 1}},
},
"two IPs": {
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Reserver leading 9 for copy modifications below
for _, ipToCopy := range testCase.toCopy {
require.NotEqual(t, 9, ipToCopy[0])
}
copied := copyIPs(testCase.toCopy)
assert.Equal(t, testCase.copied, copied)
if len(copied) > 0 {
original := testCase.toCopy[0][0]
testCase.toCopy[0][0] = 9
assert.NotEqual(t, 9, copied[0][0])
testCase.toCopy[0][0] = original
copied[0][0] = 9
assert.NotEqual(t, 9, testCase.toCopy[0][0])
}
})
}
}

View File

@@ -2,6 +2,7 @@ package models
import (
"net"
"reflect"
)
type Server struct {
@@ -26,3 +27,28 @@ type Server struct {
PortForward bool `json:"port_forward,omitempty"`
IPs []net.IP `json:"ips,omitempty"`
}
func (s *Server) Equal(other Server) (equal bool) {
if !ipsAreEqual(s.IPs, other.IPs) {
return false
}
serverCopy := *s
serverCopy.IPs = nil
other.IPs = nil
return reflect.DeepEqual(serverCopy, other)
}
func ipsAreEqual(a, b []net.IP) (equal bool) {
if len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equal(b[i]) {
return false
}
}
return true
}

View File

@@ -0,0 +1,120 @@
package models
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_Server_Equal(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
a *Server
b Server
equal bool
}{
"same IPs": {
a: &Server{
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
b: Server{
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
equal: true,
},
"same IP strings": {
a: &Server{
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
b: Server{
IPs: []net.IP{{1, 2, 3, 4}},
},
equal: true,
},
"different IPs": {
a: &Server{
IPs: []net.IP{{1, 2, 3, 4}, {2, 3, 4, 5}},
},
b: Server{
IPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 4}},
},
},
"all fields equal": {
a: &Server{
VPN: "vpn",
Country: "country",
Region: "region",
City: "city",
ISP: "isp",
Owned: true,
Number: 1,
ServerName: "server_name",
Hostname: "hostname",
TCP: true,
UDP: true,
OvpnX509: "x509",
RetroLoc: "retroloc",
MultiHop: true,
WgPubKey: "wgpubkey",
Free: true,
Stream: true,
PortForward: true,
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
b: Server{
VPN: "vpn",
Country: "country",
Region: "region",
City: "city",
ISP: "isp",
Owned: true,
Number: 1,
ServerName: "server_name",
Hostname: "hostname",
TCP: true,
UDP: true,
OvpnX509: "x509",
RetroLoc: "retroloc",
MultiHop: true,
WgPubKey: "wgpubkey",
Free: true,
Stream: true,
PortForward: true,
IPs: []net.IP{net.IPv4(1, 2, 3, 4)},
},
equal: true,
},
"different field": {
a: &Server{
VPN: "vpn",
},
b: Server{
VPN: "other vpn",
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ipsOfANotNil := testCase.a.IPs != nil
ipsOfBNotNil := testCase.b.IPs != nil
equal := testCase.a.Equal(testCase.b)
assert.Equal(t, testCase.equal, equal)
// Ensure IPs field is not modified
if ipsOfANotNil {
assert.NotNil(t, testCase.a)
}
if ipsOfBNotNil {
assert.NotNil(t, testCase.b)
}
})
}
}

View File

@@ -15,18 +15,6 @@ type AllServers struct {
ProviderToServers map[string]Servers
}
func (a *AllServers) ServersSlice(provider string) []Server {
if provider == providers.Custom {
return nil
}
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
return copyServers(servers.Servers)
}
var _ json.Marshaler = (*AllServers)(nil)
// MarshalJSON marshals all servers to JSON.

View File

@@ -0,0 +1,66 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/provider/common (interfaces: Storage)
// Package common is a generated GoMock package.
package common
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
settings "github.com/qdm12/gluetun/internal/configuration/settings"
models "github.com/qdm12/gluetun/internal/models"
)
// MockStorage is a mock of Storage interface.
type MockStorage struct {
ctrl *gomock.Controller
recorder *MockStorageMockRecorder
}
// MockStorageMockRecorder is the mock recorder for MockStorage.
type MockStorageMockRecorder struct {
mock *MockStorage
}
// NewMockStorage creates a new mock instance.
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
mock := &MockStorage{ctrl: ctrl}
mock.recorder = &MockStorageMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
// FilterServers mocks base method.
func (m *MockStorage) FilterServers(arg0 string, arg1 settings.ServerSelection) ([]models.Server, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterServers", arg0, arg1)
ret0, _ := ret[0].([]models.Server)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FilterServers indicates an expected call of FilterServers.
func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1)
}
// GetServerByName mocks base method.
func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1)
ret0, _ := ret[0].(models.Server)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetServerByName indicates an expected call of GetServerByName.
func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1)
}

View File

@@ -0,0 +1,5 @@
package common
// Exceptionally, the storage mock is exported since it is used by all
// provider subpackages tests, and it reduces test code duplication a lot.
//go:generate mockgen -destination=mocks.go -package $GOPACKAGE . Storage

View File

@@ -0,0 +1,12 @@
package common
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
)
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}

View File

@@ -2,6 +2,7 @@ package cyberghost
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Cyberghost,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost),
}

View File

@@ -2,6 +2,7 @@ package expressvpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 1195, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Expressvpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -1,41 +1,63 @@
package expressvpn
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Provider_GetConnection(t *testing.T) {
t.Parallel()
const provider = providers.Expressvpn
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct {
servers []models.Server
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
connection models.Connection
errWrapped error
errMessage string
panicMessage string
}{
"no server": {
selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn),
errWrapped: utils.ErrNoServer,
errMessage: "no server",
"error": {
storageErr: errTest,
errWrapped: errTest,
errMessage: "cannot filter servers: test error",
},
"no filter": {
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
"default OpenVPN TCP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{}.WithDefaults(providers.Expressvpn),
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
panicMessage: "no default OpenVPN TCP port is defined!",
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
@@ -43,38 +65,14 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP,
},
},
"target IP": {
"default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
}.WithDefaults(providers.Expressvpn),
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1195,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Expressvpn),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1195,
Protocol: constants.UDP,
Hostname: "b",
},
VPN: vpn.Wireguard,
}.WithDefaults(provider),
panicMessage: "no default Wireguard port is defined!",
},
}
@@ -82,12 +80,23 @@ func Test_Provider_GetConnection(t *testing.T) {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection)
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
_, _ = provider.GetConnection(testCase.selection)
})
return
}
connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn),
}

View File

@@ -2,6 +2,7 @@ package fastestvpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(4443, 4443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Fastestvpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn),
}

View File

@@ -2,6 +2,7 @@ package hidemyass
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(8080, 553, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.HideMyAss,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss),
}

View File

@@ -2,6 +2,7 @@ package ipvanish
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Ipvanish,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish),
}

View File

@@ -2,6 +2,7 @@ package ivpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 58237) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Ivpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -1,41 +1,67 @@
package ivpn
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Provider_GetConnection(t *testing.T) {
t.Parallel()
const provider = providers.Ivpn
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct {
servers []models.Server
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
connection models.Connection
errWrapped error
errMessage string
}{
"no server available": {
selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn),
errWrapped: utils.ErrNoServer,
errMessage: "no server",
"error": {
storageErr: errTest,
errWrapped: errTest,
errMessage: "cannot filter servers: test error",
},
"no filter": {
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
"default OpenVPN TCP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{}.WithDefaults(providers.Ivpn),
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 443,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
@@ -43,51 +69,36 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP,
},
},
"target IP": {
"default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
}.WithDefaults(providers.Ivpn),
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
VPN: vpn.Wireguard,
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Type: vpn.Wireguard,
IP: net.IPv4(1, 1, 1, 1),
Port: 58237,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Ivpn),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
Hostname: "b",
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection)
connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn),
}

View File

@@ -2,6 +2,7 @@ package mullvad
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Mullvad,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -1,41 +1,67 @@
package mullvad
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Provider_GetConnection(t *testing.T) {
t.Parallel()
const provider = providers.Mullvad
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct {
servers []models.Server
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
connection models.Connection
errWrapped error
errMessage string
}{
"no server available": {
selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad),
errWrapped: utils.ErrNoServer,
errMessage: "no server",
"error": {
storageErr: errTest,
errWrapped: errTest,
errMessage: "cannot filter servers: test error",
},
"no filter": {
servers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
"default OpenVPN TCP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{}.WithDefaults(providers.Mullvad),
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 443,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
@@ -43,36 +69,17 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP,
},
},
"target IP": {
"default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
}.WithDefaults(providers.Mullvad),
servers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
VPN: vpn.Wireguard,
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Mullvad),
servers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Hostname: "b",
Type: vpn.Wireguard,
IP: net.IPv4(1, 1, 1, 1),
Port: 51820,
Protocol: constants.UDP,
},
},
@@ -82,12 +89,16 @@ func Test_Provider_GetConnection(t *testing.T) {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection)
connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Mullvad),
}

View File

@@ -2,6 +2,7 @@ package nordvpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Nordvpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Nordvpn),
}

View File

@@ -2,6 +2,7 @@ package perfectprivacy
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Perfectprivacy,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Perfectprivacy),
}

View File

@@ -2,6 +2,7 @@ package privado
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Privado,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Privado),
}

View File

@@ -2,6 +2,7 @@ package privateinternetaccess
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/privateinternetaccess/presets"
"github.com/qdm12/gluetun/internal/provider/utils"
@@ -20,5 +21,6 @@ func (p *Provider) GetConnection(selection settings.ServerSelection) (
defaults.OpenVPNUDPPort = 1197
}
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.PrivateInternetAccess,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -15,12 +15,13 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/format"
)
var (
ErrServerNameNotFound = errors.New("server name not found in servers")
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
ErrServerNameEmpty = errors.New("server name is empty")
)
@@ -29,11 +30,9 @@ var (
func (p *Provider) PortForward(ctx context.Context, client *http.Client,
logger utils.Logger, gateway net.IP, serverName string) (
port uint16, err error) {
var server models.Server
for _, server = range p.servers {
if server.ServerName == serverName {
break
}
server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
if !ok {
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
}
if !server.PortForward {

View File

@@ -5,11 +5,11 @@ import (
"time"
"github.com/qdm12/gluetun/internal/constants/openvpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
timeNow func() time.Time
// Port forwarding
@@ -17,11 +17,11 @@ type Provider struct {
authFilePath string
}
func New(servers []models.Server, randSource rand.Source,
func New(storage common.Storage, randSource rand.Source,
timeNow func() time.Time) *Provider {
const jsonPortForwardPath = "/gluetun/piaportforward.json"
return &Provider{
servers: servers,
storage: storage,
timeNow: timeNow,
randSource: randSource,
portForwardPath: jsonPortForwardPath,

View File

@@ -2,6 +2,7 @@ package privatevpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Privatevpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn),
}

View File

@@ -2,6 +2,7 @@ package protonvpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Protonvpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn),
}

View File

@@ -50,52 +50,57 @@ type PortForwarder interface {
port uint16, gateway net.IP, serverName string) (err error)
}
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {
serversSlice := allServers.ServersSlice(provider)
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}
func New(provider string, storage Storage, timeNow func() time.Time) Provider {
randSource := rand.NewSource(timeNow().UnixNano())
switch provider {
case providers.Custom:
return custom.New()
case providers.Cyberghost:
return cyberghost.New(serversSlice, randSource)
return cyberghost.New(storage, randSource)
case providers.Expressvpn:
return expressvpn.New(serversSlice, randSource)
return expressvpn.New(storage, randSource)
case providers.Fastestvpn:
return fastestvpn.New(serversSlice, randSource)
return fastestvpn.New(storage, randSource)
case providers.HideMyAss:
return hidemyass.New(serversSlice, randSource)
return hidemyass.New(storage, randSource)
case providers.Ipvanish:
return ipvanish.New(serversSlice, randSource)
return ipvanish.New(storage, randSource)
case providers.Ivpn:
return ivpn.New(serversSlice, randSource)
return ivpn.New(storage, randSource)
case providers.Mullvad:
return mullvad.New(serversSlice, randSource)
return mullvad.New(storage, randSource)
case providers.Nordvpn:
return nordvpn.New(serversSlice, randSource)
return nordvpn.New(storage, randSource)
case providers.Perfectprivacy:
return perfectprivacy.New(serversSlice, randSource)
return perfectprivacy.New(storage, randSource)
case providers.Privado:
return privado.New(serversSlice, randSource)
return privado.New(storage, randSource)
case providers.PrivateInternetAccess:
return privateinternetaccess.New(serversSlice, randSource, timeNow)
return privateinternetaccess.New(storage, randSource, timeNow)
case providers.Privatevpn:
return privatevpn.New(serversSlice, randSource)
return privatevpn.New(storage, randSource)
case providers.Protonvpn:
return protonvpn.New(serversSlice, randSource)
return protonvpn.New(storage, randSource)
case providers.Purevpn:
return purevpn.New(serversSlice, randSource)
return purevpn.New(storage, randSource)
case providers.Surfshark:
return surfshark.New(serversSlice, randSource)
return surfshark.New(storage, randSource)
case providers.Torguard:
return torguard.New(serversSlice, randSource)
return torguard.New(storage, randSource)
case providers.VPNUnlimited:
return vpnunlimited.New(serversSlice, randSource)
return vpnunlimited.New(storage, randSource)
case providers.Vyprvpn:
return vyprvpn.New(serversSlice, randSource)
return vyprvpn.New(storage, randSource)
case providers.Wevpn:
return wevpn.New(serversSlice, randSource)
return wevpn.New(storage, randSource)
case providers.Windscribe:
return windscribe.New(serversSlice, randSource)
return windscribe.New(storage, randSource)
default:
panic("provider " + provider + " is unknown") // should never occur
}

View File

@@ -2,6 +2,7 @@ package purevpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(80, 53, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Purevpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn),
}

View File

@@ -2,6 +2,7 @@ package surfshark
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(1443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Surfshark,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark),
}

View File

@@ -2,6 +2,7 @@ package torguard
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(1912, 1912, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Torguard,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard),
}

View File

@@ -1,7 +1,7 @@
package utils
import (
"errors"
"fmt"
"math/rand"
"github.com/qdm12/gluetun/internal/configuration/settings"
@@ -24,20 +24,20 @@ func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort,
}
}
var ErrNoServer = errors.New("no server")
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
}
func GetConnection(servers []models.Server,
func GetConnection(provider string,
storage Storage,
selection settings.ServerSelection,
defaults ConnectionDefaults,
randSource rand.Source) (
connection models.Connection, err error) {
if len(servers) == 0 {
return connection, ErrNoServer
}
servers = filterServers(servers, selection)
if len(servers) == 0 {
return connection, noServerFoundError(selection)
servers, err := storage.FilterServers(provider, selection)
if err != nil {
return connection, fmt.Errorf("cannot filter servers: %w", err)
}
protocol := getProtocol(selection)

View File

@@ -1,23 +1,30 @@
package utils
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_GetConnection(t *testing.T) {
t.Parallel()
errTest := errors.New("test error")
testCases := map[string]struct {
servers []models.Server
provider string
filteredServers []models.Server
filterError error
serverSelection settings.ServerSelection
defaults ConnectionDefaults
randSource rand.Source
@@ -25,25 +32,13 @@ func Test_GetConnection(t *testing.T) {
errWrapped error
errMessage string
}{
"no server": {
serverSelection: settings.ServerSelection{}.
WithDefaults(providers.Mullvad),
errWrapped: ErrNoServer,
errMessage: "no server",
},
"all servers filtered": {
servers: []models.Server{
{VPN: vpn.Wireguard},
{VPN: vpn.Wireguard},
},
serverSelection: settings.ServerSelection{
VPN: vpn.OpenVPN,
}.WithDefaults(providers.Mullvad),
errWrapped: ErrNoServerFound,
errMessage: "no server found: for VPN openvpn; protocol udp",
"storage filter error": {
filterError: errTest,
errWrapped: errTest,
errMessage: "cannot filter servers: test error",
},
"server without IPs": {
servers: []models.Server{
filteredServers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true},
{VPN: vpn.OpenVPN, UDP: true},
},
@@ -58,7 +53,7 @@ func Test_GetConnection(t *testing.T) {
errMessage: "no connection to pick from",
},
"OpenVPN server with hostname": {
servers: []models.Server{
filteredServers: []models.Server{
{
VPN: vpn.OpenVPN,
UDP: true,
@@ -79,7 +74,7 @@ func Test_GetConnection(t *testing.T) {
},
},
"OpenVPN server with x509": {
servers: []models.Server{
filteredServers: []models.Server{
{
VPN: vpn.OpenVPN,
UDP: true,
@@ -101,7 +96,7 @@ func Test_GetConnection(t *testing.T) {
},
},
"server with IPv4 and IPv6": {
servers: []models.Server{
filteredServers: []models.Server{
{
VPN: vpn.OpenVPN,
UDP: true,
@@ -128,7 +123,7 @@ func Test_GetConnection(t *testing.T) {
},
},
"mixed servers": {
servers: []models.Server{
filteredServers: []models.Server{
{
VPN: vpn.OpenVPN,
UDP: true,
@@ -169,8 +164,14 @@ func Test_GetConnection(t *testing.T) {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
connection, err := GetConnection(testCase.servers,
storage := common.NewMockStorage(ctrl)
storage.EXPECT().
FilterServers(testCase.provider, testCase.serverSelection).
Return(testCase.filteredServers, testCase.filterError)
connection, err := GetConnection(testCase.provider, storage,
testCase.serverSelection, testCase.defaults,
testCase.randSource)

View File

@@ -2,6 +2,7 @@ package vpnunlimited
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.VPNUnlimited,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited),
}

View File

@@ -2,6 +2,7 @@ package vyprvpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(0, 443, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Vyprvpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn),
}

View File

@@ -2,6 +2,7 @@ package wevpn
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(1195, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Wevpn,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -1,43 +1,68 @@
package wevpn
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Provider_GetConnection(t *testing.T) {
t.Parallel()
const provider = providers.Wevpn
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct {
servers []models.Server
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
connection models.Connection
errWrapped error
errMessage string
panicMessage string
}{
"no server available": {
"error": {
storageErr: errTest,
errWrapped: errTest,
errMessage: "cannot filter servers: test error",
},
"default OpenVPN TCP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
}.WithDefaults(providers.Wevpn),
errWrapped: utils.ErrNoServer,
errMessage: "no server",
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
"no filter": {
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 1195,
Protocol: constants.TCP,
},
selection: settings.ServerSelection{}.WithDefaults(providers.Wevpn),
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
@@ -45,38 +70,14 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP,
},
},
"target IP": {
"default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
}.WithDefaults(providers.Wevpn),
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Wevpn),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Hostname: "b",
Protocol: constants.UDP,
},
VPN: vpn.Wireguard,
}.WithDefaults(provider),
panicMessage: "no default Wireguard port is defined!",
},
}
@@ -84,12 +85,23 @@ func Test_Provider_GetConnection(t *testing.T) {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
provider := New(storage, randSource)
connection, err := m.GetConnection(testCase.selection)
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
_, _ = provider.GetConnection(testCase.selection)
})
return
}
connection, err := provider.GetConnection(testCase.selection)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn),
}

View File

@@ -2,6 +2,7 @@ package windscribe
import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
@@ -9,5 +10,6 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 1194) //nolint:gomnd
return utils.GetConnection(p.servers, selection, defaults, p.randSource)
return utils.GetConnection(providers.Windscribe,
p.storage, selection, defaults, p.randSource)
}

View File

@@ -1,41 +1,68 @@
package windscribe
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/stretchr/testify/assert"
)
func Test_Provider_GetConnection(t *testing.T) {
t.Parallel()
const provider = providers.Windscribe
errTest := errors.New("test error")
boolPtr := func(b bool) *bool { return &b }
testCases := map[string]struct {
servers []models.Server
filteredServers []models.Server
storageErr error
selection settings.ServerSelection
connection models.Connection
errWrapped error
errMessage string
panicMessage string
}{
"no server available": {
selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe),
errWrapped: utils.ErrNoServer,
errMessage: "no server",
"error": {
storageErr: errTest,
errWrapped: errTest,
errMessage: "cannot filter servers: test error",
},
"no filter": {
servers: []models.Server{
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{VPN: vpn.OpenVPN, UDP: true, IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
"default OpenVPN TCP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{}.WithDefaults(providers.Windscribe),
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(true),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
Port: 443,
Protocol: constants.TCP,
},
},
"default OpenVPN UDP port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
OpenVPN: settings.OpenVPNSelection{
TCP: boolPtr(false),
},
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(1, 1, 1, 1),
@@ -43,49 +70,41 @@ func Test_Provider_GetConnection(t *testing.T) {
Protocol: constants.UDP,
},
},
"target IP": {
"default Wireguard port": {
filteredServers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
},
selection: settings.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
}.WithDefaults(providers.Windscribe),
servers: []models.Server{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
VPN: vpn.Wireguard,
}.WithDefaults(provider),
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Type: vpn.Wireguard,
IP: net.IPv4(1, 1, 1, 1),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: settings.ServerSelection{
Hostnames: []string{"b"},
}.WithDefaults(providers.Windscribe),
servers: []models.Server{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}, VPN: vpn.OpenVPN, UDP: true},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}, VPN: vpn.OpenVPN, UDP: true},
},
connection: models.Connection{
Type: vpn.OpenVPN,
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Hostname: "b",
Protocol: constants.UDP,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
storage := common.NewMockStorage(ctrl)
storage.EXPECT().FilterServers(provider, testCase.selection).
Return(testCase.filteredServers, testCase.storageErr)
randSource := rand.NewSource(0)
provider := New(testCase.servers, randSource)
provider := New(storage, randSource)
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
_, _ = provider.GetConnection(testCase.selection)
})
return
}
connection, err := provider.GetConnection(testCase.selection)
@@ -93,6 +112,7 @@ func Test_Provider_GetConnection(t *testing.T) {
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.connection, connection)
})
}

View File

@@ -4,19 +4,19 @@ import (
"math/rand"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/utils"
)
type Provider struct {
servers []models.Server
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
}
func New(servers []models.Server, randSource rand.Source) *Provider {
func New(storage common.Storage, randSource rand.Source) *Provider {
return &Provider{
servers: servers,
storage: storage,
randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Windscribe),
}

View File

@@ -0,0 +1,27 @@
package storage
import (
"github.com/qdm12/gluetun/internal/configuration/settings/validation"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
func (s *Storage) GetFilterChoices(provider string) models.FilterChoices {
if provider == providers.Custom {
return models.FilterChoices{}
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
servers := serversObject.Servers
return models.FilterChoices{
Countries: validation.ExtractCountries(servers),
Regions: validation.ExtractRegions(servers),
Cities: validation.ExtractCities(servers),
ISPs: validation.ExtractISPs(servers),
Names: validation.ExtractServerNames(servers),
Hostnames: validation.ExtractHostnames(servers),
}
}

32
internal/storage/copy.go Normal file
View File

@@ -0,0 +1,32 @@
package storage
import (
"net"
"github.com/qdm12/gluetun/internal/models"
)
func copyServer(server models.Server) (serverCopy models.Server) {
serverCopy = server
serverCopy.IPs = copyIPs(server.IPs)
return serverCopy
}
func copyIPs(toCopy []net.IP) (copied []net.IP) {
if toCopy == nil {
return nil
}
copied = make([]net.IP, len(toCopy))
for i := range toCopy {
copied[i] = copyIP(toCopy[i])
}
return copied
}
func copyIP(toCopy net.IP) (copied net.IP) {
copied = make(net.IP, len(toCopy))
copy(copied, toCopy)
return copied
}

View File

@@ -0,0 +1,75 @@
package storage
import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_copyServer(t *testing.T) {
t.Parallel()
server := models.Server{
Country: "a",
IPs: []net.IP{{1, 2, 3, 4}},
}
serverCopy := copyServer(server)
assert.Equal(t, server, serverCopy)
// Check for mutation
serverCopy.IPs[0][0] = 9
assert.NotEqual(t, server, serverCopy)
}
func Test_copyIPs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
toCopy []net.IP
copied []net.IP
}{
"nil": {},
"empty": {
toCopy: []net.IP{},
copied: []net.IP{},
},
"single IP": {
toCopy: []net.IP{{1, 1, 1, 1}},
copied: []net.IP{{1, 1, 1, 1}},
},
"two IPs": {
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Reserver leading 9 for copy modifications below
for _, ipToCopy := range testCase.toCopy {
require.NotEqual(t, 9, ipToCopy[0])
}
copied := copyIPs(testCase.toCopy)
assert.Equal(t, testCase.copied, copied)
if len(copied) > 0 {
original := testCase.toCopy[0][0]
testCase.toCopy[0][0] = 9
assert.NotEqual(t, 9, copied[0][0])
testCase.toCopy[0][0] = original
copied[0][0] = 9
assert.NotEqual(t, 9, testCase.toCopy[0][0])
}
})
}
}

143
internal/storage/filter.go Normal file
View File

@@ -0,0 +1,143 @@
package storage
import (
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
)
// FilterServers filter servers for the given provider and according
// to the given selection. The filtered servers are deep copied so they
// are safe for mutation by the caller.
func (s *Storage) FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error) {
if provider == providers.Custom {
return nil, nil
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
allServers := serversObject.Servers
if len(allServers) == 0 {
return nil, ErrNoServerFound
}
for _, server := range allServers {
if filterServer(server, selection) {
continue
}
server = copyServer(server)
servers = append(servers, server)
}
if len(servers) == 0 {
return nil, noServerFoundError(selection)
}
return servers, nil
}
func filterServer(server models.Server,
selection settings.ServerSelection) (filtered bool) {
// Note each condition is split to make sure
// we have full testing coverage.
if server.VPN != selection.VPN {
return true
}
if filterByProtocol(selection, server.TCP, server.UDP) {
return true
}
if *selection.MultiHopOnly && !server.MultiHop {
return true
}
if *selection.FreeOnly && !server.Free {
return true
}
if *selection.StreamOnly && !server.Stream {
return true
}
if *selection.OwnedOnly && !server.Owned {
return true
}
if filterByPossibilities(server.Country, selection.Countries) {
return true
}
if filterByPossibilities(server.Region, selection.Regions) {
return true
}
if filterByPossibilities(server.City, selection.Cities) {
return true
}
if filterByPossibilities(server.ISP, selection.ISPs) {
return true
}
if filterByPossibilitiesUint16(server.Number, selection.Numbers) {
return true
}
if filterByPossibilities(server.ServerName, selection.Names) {
return true
}
if filterByPossibilities(server.Hostname, selection.Hostnames) {
return true
}
// TODO filter port forward server for PIA
return false
}
func filterByPossibilities(value string, possibilities []string) (filtered bool) {
if len(possibilities) == 0 {
return false
}
for _, possibility := range possibilities {
if strings.EqualFold(value, possibility) {
return false
}
}
return true
}
// TODO merge with filterByPossibilities with generics in Go 1.18.
func filterByPossibilitiesUint16(value uint16, possibilities []uint16) (filtered bool) {
if len(possibilities) == 0 {
return false
}
for _, possibility := range possibilities {
if value == possibility {
return false
}
}
return true
}
func filterByProtocol(selection settings.ServerSelection,
serverTCP, serverUDP bool) (filtered bool) {
switch selection.VPN {
case vpn.Wireguard:
return !serverUDP
default: // OpenVPN
wantTCP := *selection.OpenVPN.TCP
wantUDP := !wantTCP
return (wantTCP && !serverTCP) || (wantUDP && !serverUDP)
}
}

View File

@@ -4,21 +4,20 @@ import (
"encoding/json"
"os"
"path/filepath"
"github.com/qdm12/gluetun/internal/models"
)
var _ Flusher = (*Storage)(nil)
// FlushToFile flushes the merged servers data to the file
// specified by path, as indented JSON.
func (s *Storage) FlushToFile(path string) error {
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
type Flusher interface {
FlushToFile(allServers *models.AllServers) error
return s.flushToFile(path)
}
func (s *Storage) FlushToFile(allServers *models.AllServers) error {
return flushToFile(s.filepath, allServers)
}
func flushToFile(path string, servers *models.AllServers) error {
// flushToFile flushes the merged servers data to the file
// specified by path, as indented JSON. It is not thread-safe.
func (s *Storage) flushToFile(path string) error {
dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil {
return err
@@ -28,11 +27,15 @@ func flushToFile(path string, servers *models.AllServers) error {
if err != nil {
return err
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
err = encoder.Encode(&s.mergedServers)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -1,4 +1,4 @@
package utils
package storage
import (
"errors"

View File

@@ -1,7 +1,118 @@
package storage
import "github.com/qdm12/gluetun/internal/models"
import (
"fmt"
"time"
func (s *Storage) GetServers() models.AllServers {
return s.mergedServers.GetCopy()
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
// SetServers sets the given servers for the given provider
// in the storage in-memory map and saves all the servers
// to file.
// Note the servers given are not copied so the caller must
// NOT MUTATE them after calling this method.
func (s *Storage) SetServers(provider string, servers []models.Server) (err error) {
if provider == providers.Custom {
return
}
s.mergedMutex.Lock()
defer s.mergedMutex.Unlock()
serversObject := s.getMergedServersObject(provider)
serversObject.Timestamp = time.Now().Unix()
serversObject.Servers = servers
s.mergedServers.ProviderToServers[provider] = serversObject
err = s.flushToFile(s.filepath)
if err != nil {
return fmt.Errorf("cannot save servers to file: %w", err)
}
return nil
}
// GetServerByName returns the server for the given provider
// and server name. It returns `ok` as false if the server is
// not found. The returned server is also deep copied so it is
// safe for mutation and/or thread safe use.
func (s *Storage) GetServerByName(provider, name string) (
server models.Server, ok bool) {
if provider == providers.Custom {
return server, false
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
for _, server := range serversObject.Servers {
if server.ServerName == name {
return copyServer(server), true
}
}
return server, false
}
// GetServersCount returns the number of servers for the provider given.
func (s *Storage) GetServersCount(provider string) (count int) {
if provider == providers.Custom {
return 0
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
return len(serversObject.Servers)
}
// FormatToMarkdown Markdown formats the servers for the provider given
// and returns the resulting string.
func (s *Storage) FormatToMarkdown(provider string) (formatted string) {
if provider == providers.Custom {
return ""
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
formatted = serversObject.ToMarkdown(provider)
return formatted
}
// GetServersCount returns the number of servers for the provider given.
func (s *Storage) ServersAreEqual(provider string, servers []models.Server) (equal bool) {
if provider == providers.Custom {
return true
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
existingServers := serversObject.Servers
if len(existingServers) != len(servers) {
return false
}
for i := range existingServers {
if !existingServers[i].Equal(servers[i]) {
return false
}
}
return true
}
func (s *Storage) getMergedServersObject(provider string) (serversObject models.Servers) {
serversObject, ok := s.mergedServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in in-memory servers map", provider))
}
return serversObject
}

View File

@@ -2,11 +2,14 @@
package storage
import (
"sync"
"github.com/qdm12/gluetun/internal/models"
)
type Storage struct {
mergedServers models.AllServers
mergedMutex sync.RWMutex
// this is stored in memory to avoid re-parsing
// the embedded JSON file on every call to the
// SyncServers method.

View File

@@ -29,6 +29,9 @@ func (s *Storage) syncServers() (err error) {
hardcodedCount := countServers(s.hardcodedServers)
countOnFile := countServers(serversOnFile)
s.mergedMutex.Lock()
defer s.mergedMutex.Unlock()
if countOnFile == 0 {
s.logger.Info(fmt.Sprintf(
"creating %s with %d hardcoded servers",
@@ -47,7 +50,8 @@ func (s *Storage) syncServers() (err error) {
return nil
}
if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
err = s.flushToFile(s.filepath)
if err != nil {
return fmt.Errorf("cannot write servers to file: %w", err)
}
return nil

View File

@@ -9,7 +9,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater"
)
@@ -24,15 +23,13 @@ type Looper interface {
}
type Updater interface {
UpdateServers(ctx context.Context) (allServers models.AllServers, err error)
UpdateServers(ctx context.Context) (err error)
}
type looper struct {
state state
// Objects
updater Updater
flusher storage.Flusher
setAllServers func(allServers models.AllServers)
logger Logger
// Internal channels and locks
loopLock sync.Mutex
@@ -49,23 +46,26 @@ type looper struct {
const defaultBackoffTime = 5 * time.Second
type Storage interface {
SetServers(provider string, servers []models.Server) (err error)
GetServersCount(provider string) (count int)
ServersAreEqual(provider string, servers []models.Server) (equal bool)
}
type Logger interface {
Info(s string)
Warn(s string)
Error(s string)
}
func NewLooper(settings settings.Updater, currentServers models.AllServers,
flusher storage.Flusher, setAllServers func(allServers models.AllServers),
func NewLooper(settings settings.Updater, storage Storage,
client *http.Client, logger Logger) Looper {
return &looper{
state: state{
status: constants.Stopped,
settings: settings,
},
updater: updater.New(settings, client, currentServers, logger),
flusher: flusher,
setAllServers: setAllServers,
updater: updater.New(settings, client, storage, logger),
logger: logger,
start: make(chan struct{}),
running: make(chan models.LoopStatus),
@@ -106,20 +106,19 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
for ctx.Err() == nil {
updateCtx, updateCancel := context.WithCancel(ctx)
serversCh := make(chan models.AllServers)
errorCh := make(chan error)
runWg := &sync.WaitGroup{}
runWg.Add(1)
go func() {
defer runWg.Done()
servers, err := l.updater.UpdateServers(updateCtx)
err := l.updater.UpdateServers(updateCtx)
if err != nil {
if updateCtx.Err() == nil {
errorCh <- err
}
return
}
serversCh <- servers
l.state.setStatusWithLock(constants.Completed)
}()
if !crashed {
@@ -148,16 +147,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
updateCancel()
runWg.Wait()
l.stopped <- struct{}{}
case servers := <-serversCh:
l.setAllServers(servers)
if err := l.flusher.FlushToFile(&servers); err != nil {
l.logger.Error(err.Error())
}
runWg.Wait()
l.state.setStatusWithLock(constants.Completed)
l.logger.Info("Updated servers information")
case err := <-errorCh:
close(serversCh)
runWg.Wait()
l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err)

View File

@@ -3,8 +3,6 @@ package updater
import (
"context"
"fmt"
"reflect"
"time"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
@@ -31,18 +29,25 @@ import (
)
func (u *Updater) updateProvider(ctx context.Context, provider string) (err error) {
existingServers := u.getProviderServers(provider)
minServers := getMinServers(existingServers)
existingServersCount := u.storage.GetServersCount(provider)
minServers := getMinServers(existingServersCount)
servers, err := u.getServers(ctx, provider, minServers)
if err != nil {
return err
return fmt.Errorf("cannot get servers: %w", err)
}
if reflect.DeepEqual(existingServers, servers) {
if u.storage.ServersAreEqual(provider, servers) {
return nil
}
u.patchProvider(provider, servers)
// Note the servers variable must NOT BE MUTATED after this call,
// since the implementation does not deep copy the servers.
// TODO set in storage in provider updater directly, server by server,
// to avoid accumulating server data in memory.
err = u.storage.SetServers(provider, servers)
if err != nil {
return fmt.Errorf("cannot set servers to storage: %w", err)
}
return nil
}
@@ -101,25 +106,7 @@ func (u *Updater) getServers(ctx context.Context, provider string,
return providerUpdater.GetServers(ctx, minServers)
}
func (u *Updater) getProviderServers(provider string) (servers []models.Server) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
return providerServers.Servers
}
func getMinServers(servers []models.Server) (minServers int) {
func getMinServers(existingServersCount int) (minServers int) {
const minRatio = 0.8
return int(minRatio * float64(len(servers)))
}
func (u *Updater) patchProvider(provider string, servers []models.Server) {
providerServers, ok := u.servers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s is unknown", provider))
}
providerServers.Timestamp = time.Now().Unix()
providerServers.Servers = servers
u.servers.ProviderToServers[provider] = providerServers
return int(minRatio * float64(existingServersCount))
}

View File

@@ -19,7 +19,7 @@ type Updater struct {
options settings.Updater
// state
servers models.AllServers
storage Storage
// Functions for tests
logger Logger
@@ -29,6 +29,12 @@ type Updater struct {
unzipper unzip.Unzipper
}
type Storage interface {
SetServers(provider string, servers []models.Server) (err error)
GetServersCount(provider string) (count int)
ServersAreEqual(provider string, servers []models.Server) (equal bool)
}
type Logger interface {
Info(s string)
Warn(s string)
@@ -36,20 +42,20 @@ type Logger interface {
}
func New(settings settings.Updater, httpClient *http.Client,
currentServers models.AllServers, logger Logger) *Updater {
storage Storage, logger Logger) *Updater {
unzipper := unzip.New(httpClient)
return &Updater{
options: settings,
storage: storage,
logger: logger,
timeNow: time.Now,
presolver: resolver.NewParallelResolver(settings.DNSAddress.String()),
client: httpClient,
unzipper: unzipper,
options: settings,
servers: currentServers,
}
}
func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) {
func (u *Updater) UpdateServers(ctx context.Context) (err error) {
caser := cases.Title(language.English)
for _, provider := range u.options.Providers {
u.logger.Info("updating " + caser.String(provider) + " servers...")
@@ -62,17 +68,17 @@ func (u *Updater) UpdateServers(ctx context.Context) (allServers models.AllServe
// return the only error for the single provider.
if len(u.options.Providers) == 1 {
return allServers, err
return err
}
// stop updating the next providers if context is canceled.
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
return ctxErr
}
// Log the error and continue updating the next provider.
u.logger.Error(err.Error())
}
return u.servers, nil
return nil
}

View File

@@ -27,12 +27,12 @@ type Looper interface {
loopstate.Getter
loopstate.Applier
SettingsGetSetter
ServersGetterSetter
}
type Loop struct {
statusManager loopstate.Manager
state state.Manager
storage Storage
// Fixed parameters
buildInfo models.BuildInformation
versionInfo bool
@@ -64,12 +64,17 @@ type firewallConfigurer interface {
firewall.PortAllower
}
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}
const (
defaultBackoffTime = 15 * time.Second
)
func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16,
allServers models.AllServers, openvpnConf openvpn.Interface,
storage Storage, openvpnConf openvpn.Interface,
netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter,
portForward portforward.StartStopper, starter command.Starter,
publicip publicip.Looper, dnsLooper dns.Looper,
@@ -81,11 +86,12 @@ func NewLoop(vpnSettings settings.VPN, vpnInputPorts []uint16,
stopped := make(chan struct{})
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, vpnSettings, allServers)
state := state.New(statusManager, vpnSettings)
return &Loop{
statusManager: statusManager,
state: state,
storage: storage,
buildInfo: buildInfo,
versionInfo: versionInfo,
vpnInputPorts: vpnInputPorts,

View File

@@ -28,9 +28,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
}
for ctx.Err() == nil {
settings, allServers := l.state.GetSettingsAndServers()
settings := l.state.GetSettings()
providerConf := provider.New(*settings.Provider.Name, allServers, time.Now)
providerConf := provider.New(*settings.Provider.Name, l.storage, time.Now)
portForwarding := *settings.Provider.PortForwarding.Enabled
var vpnRunner vpnRunner

View File

@@ -1,16 +0,0 @@
package vpn
import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/vpn/state"
)
type ServersGetterSetter = state.ServersGetterSetter
func (l *Loop) GetServers() (servers models.AllServers) {
return l.state.GetServers()
}
func (l *Loop) SetServers(servers models.AllServers) {
l.state.SetServers(servers)
}

View File

@@ -1,20 +0,0 @@
package state
import "github.com/qdm12/gluetun/internal/models"
type ServersGetterSetter interface {
GetServers() (servers models.AllServers)
SetServers(servers models.AllServers)
}
func (s *State) GetServers() (servers models.AllServers) {
s.allServersMu.RLock()
defer s.allServersMu.RUnlock()
return s.allServers
}
func (s *State) SetServers(servers models.AllServers) {
s.allServersMu.Lock()
defer s.allServersMu.Unlock()
s.allServers = servers
}

View File

@@ -5,23 +5,18 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
)
var _ Manager = (*State)(nil)
type Manager interface {
SettingsGetSetter
ServersGetterSetter
GetSettingsAndServers() (vpn settings.VPN, allServers models.AllServers)
}
func New(statusApplier loopstate.Applier,
vpn settings.VPN, allServers models.AllServers) *State {
func New(statusApplier loopstate.Applier, vpn settings.VPN) *State {
return &State{
statusApplier: statusApplier,
vpn: vpn,
allServers: allServers,
}
}
@@ -30,18 +25,4 @@ type State struct {
vpn settings.VPN
settingsMu sync.RWMutex
allServers models.AllServers
allServersMu sync.RWMutex
}
func (s *State) GetSettingsAndServers() (vpn settings.VPN,
allServers models.AllServers) {
s.settingsMu.RLock()
s.allServersMu.RLock()
vpn = s.vpn
allServers = s.allServers
s.settingsMu.RUnlock()
s.allServersMu.RUnlock()
return vpn, allServers
}