From 447a7c9891af5f608d2ed2a951bfce186edcfb81 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 11 Jun 2022 17:41:57 +0000 Subject: [PATCH] updater: refactoring and set DNS server correctly - Fix CLI operation not setting DNS server - Fix periodic operation not setting DNS server - Set DNS address for resolution once at start for both CLI and periodic operation - Inject resolver to each provider instead of creating it within - Use resolver settings on every call to `.Resolve` method, instead of passing it to constructor - Move out minServers check from resolver --- cmd/gluetun/main.go | 5 +- internal/cli/openvpnconfig.go | 10 +++- internal/cli/update.go | 9 +-- internal/provider/common/mocks.go | 9 +-- internal/provider/common/updater.go | 3 +- internal/provider/cyberghost/provider.go | 5 +- .../provider/cyberghost/updater/resolve.go | 6 +- .../provider/cyberghost/updater/servers.go | 10 +++- .../provider/cyberghost/updater/updater.go | 4 +- .../provider/expressvpn/connection_test.go | 3 +- internal/provider/expressvpn/provider.go | 5 +- .../provider/expressvpn/updater/resolve.go | 6 +- .../provider/expressvpn/updater/servers.go | 3 +- .../provider/expressvpn/updater/updater.go | 5 +- internal/provider/fastestvpn/provider.go | 5 +- .../provider/fastestvpn/updater/resolve.go | 6 +- .../provider/fastestvpn/updater/servers.go | 13 +++-- .../provider/fastestvpn/updater/updater.go | 5 +- internal/provider/hidemyass/provider.go | 5 +- .../provider/hidemyass/updater/resolve.go | 6 +- .../provider/hidemyass/updater/servers.go | 8 ++- .../provider/hidemyass/updater/updater.go | 5 +- internal/provider/ipvanish/provider.go | 5 +- internal/provider/ipvanish/updater/resolve.go | 6 +- internal/provider/ipvanish/updater/servers.go | 13 +++-- .../provider/ipvanish/updater/servers_test.go | 58 ++++++++++++++----- internal/provider/ipvanish/updater/updater.go | 5 +- internal/provider/ivpn/connection_test.go | 3 +- internal/provider/ivpn/provider.go | 5 +- internal/provider/ivpn/updater/resolve.go | 6 +- internal/provider/ivpn/updater/servers.go | 8 ++- .../provider/ivpn/updater/servers_test.go | 34 +++++++++-- internal/provider/ivpn/updater/updater.go | 5 +- internal/provider/privado/provider.go | 5 +- internal/provider/privado/updater/resolve.go | 6 +- internal/provider/privado/updater/servers.go | 8 ++- internal/provider/privado/updater/updater.go | 4 +- internal/provider/privatevpn/provider.go | 5 +- .../provider/privatevpn/updater/resolve.go | 6 +- .../provider/privatevpn/updater/servers.go | 8 ++- .../provider/privatevpn/updater/updater.go | 5 +- internal/provider/protonvpn/provider.go | 3 +- internal/provider/providers.go | 51 ++++++++-------- internal/provider/purevpn/provider.go | 5 +- internal/provider/purevpn/updater/resolve.go | 6 +- internal/provider/purevpn/updater/servers.go | 13 +++-- internal/provider/purevpn/updater/updater.go | 4 +- internal/provider/surfshark/provider.go | 5 +- .../provider/surfshark/updater/resolve.go | 6 +- .../provider/surfshark/updater/servers.go | 13 +++-- .../provider/surfshark/updater/updater.go | 4 +- internal/provider/torguard/provider.go | 5 +- internal/provider/torguard/updater/resolve.go | 6 +- internal/provider/torguard/updater/servers.go | 8 ++- internal/provider/torguard/updater/updater.go | 5 +- internal/provider/vpnunlimited/provider.go | 5 +- .../provider/vpnunlimited/updater/resolve.go | 6 +- .../provider/vpnunlimited/updater/servers.go | 13 +++-- .../provider/vpnunlimited/updater/updater.go | 5 +- internal/provider/vyprvpn/provider.go | 5 +- internal/provider/vyprvpn/updater/resolve.go | 6 +- internal/provider/vyprvpn/updater/servers.go | 13 +++-- internal/provider/vyprvpn/updater/updater.go | 5 +- internal/provider/wevpn/connection_test.go | 3 +- internal/provider/wevpn/provider.go | 5 +- internal/provider/wevpn/updater/resolve.go | 6 +- internal/provider/wevpn/updater/servers.go | 3 +- internal/provider/wevpn/updater/updater.go | 4 +- internal/updater/resolver/parallel.go | 36 +++++------- internal/updater/resolver/repeat.go | 15 +++-- 70 files changed, 366 insertions(+), 229 deletions(-) diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 7baa1a51..830b588d 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -39,6 +39,7 @@ import ( "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/tun" updater "github.com/qdm12/gluetun/internal/updater/loop" + "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/vpn" "github.com/qdm12/golibs/command" @@ -379,7 +380,9 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, updaterLogger := logger.New(log.SetComponent("updater")) unzipper := unzip.New(httpClient) - providers := provider.NewProviders(storage, time.Now, updaterLogger, httpClient, unzipper) + parallelResolver := resolver.NewParallelResolver(allSettings.Updater.DNSAddress.String()) + providers := provider.NewProviders(storage, time.Now, + updaterLogger, httpClient, unzipper, parallelResolver) vpnLogger := logger.New(log.SetComponent("vpn")) vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts, diff --git a/internal/cli/openvpnconfig.go b/internal/cli/openvpnconfig.go index a07d96e0..3707efac 100644 --- a/internal/cli/openvpnconfig.go +++ b/internal/cli/openvpnconfig.go @@ -3,6 +3,7 @@ package cli import ( "context" "fmt" + "net" "net/http" "strings" "time" @@ -11,6 +12,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/storage" + "github.com/qdm12/gluetun/internal/updater/resolver" ) type OpenvpnConfigLogger interface { @@ -23,6 +25,11 @@ type Unzipper interface { contents map[string][]byte, err error) } +type ParallelResolver interface { + Resolve(ctx context.Context, settings resolver.ParallelSettings) ( + hostToIPs map[string][]net.IP, warnings []string, err error) +} + func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) error { storage, err := storage.New(logger, constants.ServersData) if err != nil { @@ -42,8 +49,9 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) e unzipper := (Unzipper)(nil) client := (*http.Client)(nil) warner := (Warner)(nil) + parallelResolver := (ParallelResolver)(nil) - providers := provider.NewProviders(storage, time.Now, warner, client, unzipper) + providers := provider.NewProviders(storage, time.Now, warner, client, unzipper, parallelResolver) providerConf := providers.Get(*allSettings.VPN.Provider.Name) connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection) if err != nil { diff --git a/internal/cli/update.go b/internal/cli/update.go index 3256dded..3350eae8 100644 --- a/internal/cli/update.go +++ b/internal/cli/update.go @@ -16,6 +16,7 @@ import ( "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/updater" + "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/unzip" ) @@ -71,17 +72,17 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e return fmt.Errorf("options validation failed: %w", err) } - const clientTimeout = 10 * time.Second - httpClient := &http.Client{Timeout: clientTimeout} - storage, err := storage.New(logger, constants.ServersData) if err != nil { return fmt.Errorf("cannot create servers storage: %w", err) } + const clientTimeout = 10 * time.Second + httpClient := &http.Client{Timeout: clientTimeout} unzipper := unzip.New(httpClient) + parallelResolver := resolver.NewParallelResolver(options.DNSAddress.String()) - providers := provider.NewProviders(storage, time.Now, logger, httpClient, unzipper) + providers := provider.NewProviders(storage, time.Now, logger, httpClient, unzipper, parallelResolver) updater := updater.New(httpClient, storage, providers, logger) err = updater.UpdateServers(ctx, options.Providers) diff --git a/internal/provider/common/mocks.go b/internal/provider/common/mocks.go index 3c19b295..869e2864 100644 --- a/internal/provider/common/mocks.go +++ b/internal/provider/common/mocks.go @@ -12,6 +12,7 @@ import ( gomock "github.com/golang/mock/gomock" settings "github.com/qdm12/gluetun/internal/configuration/settings" models "github.com/qdm12/gluetun/internal/models" + resolver "github.com/qdm12/gluetun/internal/updater/resolver" ) // MockParallelResolver is a mock of ParallelResolver interface. @@ -38,9 +39,9 @@ func (m *MockParallelResolver) EXPECT() *MockParallelResolverMockRecorder { } // Resolve mocks base method. -func (m *MockParallelResolver) Resolve(arg0 context.Context, arg1 []string, arg2 int) (map[string][]net.IP, []string, error) { +func (m *MockParallelResolver) Resolve(arg0 context.Context, arg1 resolver.ParallelSettings) (map[string][]net.IP, []string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Resolve", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Resolve", arg0, arg1) ret0, _ := ret[0].(map[string][]net.IP) ret1, _ := ret[1].([]string) ret2, _ := ret[2].(error) @@ -48,9 +49,9 @@ func (m *MockParallelResolver) Resolve(arg0 context.Context, arg1 []string, arg2 } // Resolve indicates an expected call of Resolve. -func (mr *MockParallelResolverMockRecorder) Resolve(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockParallelResolverMockRecorder) Resolve(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockParallelResolver)(nil).Resolve), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockParallelResolver)(nil).Resolve), arg0, arg1) } // MockStorage is a mock of Storage interface. diff --git a/internal/provider/common/updater.go b/internal/provider/common/updater.go index 905650e7..e4bb0db5 100644 --- a/internal/provider/common/updater.go +++ b/internal/provider/common/updater.go @@ -6,6 +6,7 @@ import ( "net" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) var ErrNotEnoughServers = errors.New("not enough servers found") @@ -15,7 +16,7 @@ type Fetcher interface { } type ParallelResolver interface { - Resolve(ctx context.Context, hosts []string, minToFind int) ( + Resolve(ctx context.Context, settings resolver.ParallelSettings) ( hostToIPs map[string][]net.IP, warnings []string, err error) } diff --git a/internal/provider/cyberghost/provider.go b/internal/provider/cyberghost/provider.go index a8d6e90b..2ea8f987 100644 --- a/internal/provider/cyberghost/provider.go +++ b/internal/provider/cyberghost/provider.go @@ -16,12 +16,13 @@ type Provider struct { common.Fetcher } -func New(storage common.Storage, randSource rand.Source) *Provider { +func New(storage common.Storage, randSource rand.Source, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost), - Fetcher: updater.New(), + Fetcher: updater.New(parallelResolver), } } diff --git a/internal/provider/cyberghost/updater/resolve.go b/internal/provider/cyberghost/updater/resolve.go index 23f496a1..49e984cb 100644 --- a/internal/provider/cyberghost/updater/resolve.go +++ b/internal/provider/cyberghost/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 4 maxFails = 10 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/cyberghost/updater/servers.go b/internal/provider/cyberghost/updater/servers.go index 963da043..a33560ea 100644 --- a/internal/provider/cyberghost/updater/servers.go +++ b/internal/provider/cyberghost/updater/servers.go @@ -4,9 +4,11 @@ package updater import ( "context" + "fmt" "sort" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider/common" ) func (u *Updater) FetchServers(ctx context.Context, minServers int) ( @@ -14,11 +16,17 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( possibleServers := getPossibleServers() possibleHosts := possibleServers.hostsSlice() - hostToIPs, _, err := u.presolver.Resolve(ctx, possibleHosts, minServers) + resolveSettings := parallelResolverSettings(possibleHosts) + hostToIPs, _, err := u.presolver.Resolve(ctx, resolveSettings) if err != nil { return nil, err } + if len(hostToIPs) < minServers { + return nil, fmt.Errorf("%w: %d and expected at least %d", + common.ErrNotEnoughServers, len(servers), minServers) + } + possibleServers.adaptWithIPs(hostToIPs) servers = possibleServers.toSlice() diff --git a/internal/provider/cyberghost/updater/updater.go b/internal/provider/cyberghost/updater/updater.go index 326513ff..b46befa2 100644 --- a/internal/provider/cyberghost/updater/updater.go +++ b/internal/provider/cyberghost/updater/updater.go @@ -8,8 +8,8 @@ type Updater struct { presolver common.ParallelResolver } -func New() *Updater { +func New(parallelResolver common.ParallelResolver) *Updater { return &Updater{ - presolver: newParallelResolver(), + presolver: parallelResolver, } } diff --git a/internal/provider/expressvpn/connection_test.go b/internal/provider/expressvpn/connection_test.go index eb6bbfa0..b7160088 100644 --- a/internal/provider/expressvpn/connection_test.go +++ b/internal/provider/expressvpn/connection_test.go @@ -89,7 +89,8 @@ func Test_Provider_GetConnection(t *testing.T) { unzipper := (common.Unzipper)(nil) warner := (common.Warner)(nil) - provider := New(storage, randSource, unzipper, warner) + parallelResolver := (common.ParallelResolver)(nil) + provider := New(storage, randSource, unzipper, warner, parallelResolver) if testCase.panicMessage != "" { assert.PanicsWithValue(t, testCase.panicMessage, func() { diff --git a/internal/provider/expressvpn/provider.go b/internal/provider/expressvpn/provider.go index c483e19b..2c746d64 100644 --- a/internal/provider/expressvpn/provider.go +++ b/internal/provider/expressvpn/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/expressvpn/updater/resolve.go b/internal/provider/expressvpn/updater/resolve.go index c38ed35a..85d315cd 100644 --- a/internal/provider/expressvpn/updater/resolve.go +++ b/internal/provider/expressvpn/updater/resolve.go @@ -6,13 +6,14 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() *resolver.Parallel { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxNoNew = 1 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: time.Second, @@ -21,5 +22,4 @@ func newParallelResolver() *resolver.Parallel { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/expressvpn/updater/servers.go b/internal/provider/expressvpn/updater/servers.go index dec25cbf..9280a75f 100644 --- a/internal/provider/expressvpn/updater/servers.go +++ b/internal/provider/expressvpn/updater/servers.go @@ -21,7 +21,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( hosts[i] = servers[i].Hostname } - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } diff --git a/internal/provider/expressvpn/updater/updater.go b/internal/provider/expressvpn/updater/updater.go index 452ef1ce..de83c263 100644 --- a/internal/provider/expressvpn/updater/updater.go +++ b/internal/provider/expressvpn/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { warner common.Warner } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/fastestvpn/provider.go b/internal/provider/fastestvpn/provider.go index eb695827..93710796 100644 --- a/internal/provider/fastestvpn/provider.go +++ b/internal/provider/fastestvpn/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/fastestvpn/updater/resolve.go b/internal/provider/fastestvpn/updater/resolve.go index c38ed35a..85d315cd 100644 --- a/internal/provider/fastestvpn/updater/resolve.go +++ b/internal/provider/fastestvpn/updater/resolve.go @@ -6,13 +6,14 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() *resolver.Parallel { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxNoNew = 1 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: time.Second, @@ -21,5 +22,4 @@ func newParallelResolver() *resolver.Parallel { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/fastestvpn/updater/servers.go b/internal/provider/fastestvpn/updater/servers.go index 6b87c055..9c59da3e 100644 --- a/internal/provider/fastestvpn/updater/servers.go +++ b/internal/provider/fastestvpn/updater/servers.go @@ -56,7 +56,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -64,15 +65,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } - hts.adaptWithIPs(hostToIPs) - - servers = hts.toServersSlice() - - if len(servers) < minServers { + if len(hostToIPs) < minServers { return nil, fmt.Errorf("%w: %d and expected at least %d", common.ErrNotEnoughServers, len(servers), minServers) } + hts.adaptWithIPs(hostToIPs) + + servers = hts.toServersSlice() + sort.Sort(models.SortableServers(servers)) return servers, nil diff --git a/internal/provider/fastestvpn/updater/updater.go b/internal/provider/fastestvpn/updater/updater.go index 452ef1ce..de83c263 100644 --- a/internal/provider/fastestvpn/updater/updater.go +++ b/internal/provider/fastestvpn/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { warner common.Warner } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/hidemyass/provider.go b/internal/provider/hidemyass/provider.go index 9006cdf9..f4bc37dc 100644 --- a/internal/provider/hidemyass/provider.go +++ b/internal/provider/hidemyass/provider.go @@ -18,12 +18,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - client *http.Client, updaterWarner common.Warner) *Provider { + client *http.Client, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss), - Fetcher: updater.New(client, updaterWarner), + Fetcher: updater.New(client, updaterWarner, parallelResolver), } } diff --git a/internal/provider/hidemyass/updater/resolve.go b/internal/provider/hidemyass/updater/resolve.go index 9a95d22e..e3f76174 100644 --- a/internal/provider/hidemyass/updater/resolve.go +++ b/internal/provider/hidemyass/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() *resolver.Parallel { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 15 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() *resolver.Parallel { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() *resolver.Parallel { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/hidemyass/updater/servers.go b/internal/provider/hidemyass/updater/servers.go index e80a29f2..44d3496d 100644 --- a/internal/provider/hidemyass/updater/servers.go +++ b/internal/provider/hidemyass/updater/servers.go @@ -26,7 +26,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( common.ErrNotEnoughServers, len(hosts), minServers) } - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -34,6 +35,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } + if len(hostToIPs) < minServers { + return nil, fmt.Errorf("%w: %d and expected at least %d", + common.ErrNotEnoughServers, len(servers), minServers) + } + servers = make([]models.Server, 0, len(hostToIPs)) for host, IPs := range hostToIPs { tcpURL, tcp := tcpHostToURL[host] diff --git a/internal/provider/hidemyass/updater/updater.go b/internal/provider/hidemyass/updater/updater.go index fbecad18..f75e7b12 100644 --- a/internal/provider/hidemyass/updater/updater.go +++ b/internal/provider/hidemyass/updater/updater.go @@ -12,10 +12,11 @@ type Updater struct { warner common.Warner } -func New(client *http.Client, warner common.Warner) *Updater { +func New(client *http.Client, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ client: client, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/ipvanish/provider.go b/internal/provider/ipvanish/provider.go index 1ecef7cf..cb17137e 100644 --- a/internal/provider/ipvanish/provider.go +++ b/internal/provider/ipvanish/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/ipvanish/updater/resolve.go b/internal/provider/ipvanish/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/ipvanish/updater/resolve.go +++ b/internal/provider/ipvanish/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/ipvanish/updater/servers.go b/internal/provider/ipvanish/updater/servers.go index 168e3803..f3cd8a1d 100644 --- a/internal/provider/ipvanish/updater/servers.go +++ b/internal/provider/ipvanish/updater/servers.go @@ -67,7 +67,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -75,15 +76,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } - hts.adaptWithIPs(hostToIPs) - - servers = hts.toServersSlice() - - if len(servers) < minServers { + if len(hostToIPs) < minServers { return nil, fmt.Errorf("%w: %d and expected at least %d", common.ErrNotEnoughServers, len(servers), minServers) } + hts.adaptWithIPs(hostToIPs) + + servers = hts.toServersSlice() + sort.Sort(models.SortableServers(servers)) return servers, nil diff --git a/internal/provider/ipvanish/updater/servers_test.go b/internal/provider/ipvanish/updater/servers_test.go index 0ce64ce8..5f92f968 100644 --- a/internal/provider/ipvanish/updater/servers_test.go +++ b/internal/provider/ipvanish/updater/servers_test.go @@ -5,11 +5,13 @@ import ( "errors" "net" "testing" + "time" "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common" + "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -28,11 +30,11 @@ func Test_Updater_GetServers(t *testing.T) { unzipErr error // Resolution - expectResolve bool - hostsToResolve []string - hostToIPs map[string][]net.IP - resolveWarnings []string - resolveErr error + expectResolve bool + resolverSettings resolver.ParallelSettings + hostToIPs map[string][]net.IP + resolveWarnings []string + resolveErr error // Output servers []models.Server @@ -85,9 +87,19 @@ func Test_Updater_GetServers(t *testing.T) { unzipContents: map[string][]byte{ "ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta\nremote hostb"), }, - expectResolve: true, - hostsToResolve: []string{"hosta"}, - err: errors.New("not enough servers found: 0 and expected at least 1"), + expectResolve: true, + resolverSettings: resolver.ParallelSettings{ + Hosts: []string{"hosta"}, + MaxFailRatio: 0.1, + Repeat: resolver.RepeatSettings{ + MaxDuration: 20 * time.Second, + BetweenDuration: time.Second, + MaxNoNew: 2, + MaxFails: 2, + SortIPs: true, + }, + }, + err: errors.New("not enough servers found: 0 and expected at least 1"), }, "resolve error": { warnerBuilder: func(ctrl *gomock.Controller) common.Warner { @@ -98,8 +110,18 @@ func Test_Updater_GetServers(t *testing.T) { unzipContents: map[string][]byte{ "ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta"), }, - expectResolve: true, - hostsToResolve: []string{"hosta"}, + expectResolve: true, + resolverSettings: resolver.ParallelSettings{ + Hosts: []string{"hosta"}, + MaxFailRatio: 0.1, + Repeat: resolver.RepeatSettings{ + MaxDuration: 20 * time.Second, + BetweenDuration: time.Second, + MaxNoNew: 2, + MaxFails: 2, + SortIPs: true, + }, + }, resolveWarnings: []string{"resolve warning"}, resolveErr: errors.New("dummy"), err: errors.New("dummy"), @@ -127,8 +149,18 @@ func Test_Updater_GetServers(t *testing.T) { "ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta"), "ipvanish-LU-City-B-hostb.ovpn": []byte("remote hostb"), }, - expectResolve: true, - hostsToResolve: []string{"hosta", "hostb"}, + expectResolve: true, + resolverSettings: resolver.ParallelSettings{ + Hosts: []string{"hosta", "hostb"}, + MaxFailRatio: 0.1, + Repeat: resolver.RepeatSettings{ + MaxDuration: 20 * time.Second, + BetweenDuration: time.Second, + MaxNoNew: 2, + MaxFails: 2, + SortIPs: true, + }, + }, hostToIPs: map[string][]net.IP{ "hosta": {{1, 1, 1, 1}, {2, 2, 2, 2}}, "hostb": {{3, 3, 3, 3}, {4, 4, 4, 4}}, @@ -169,7 +201,7 @@ func Test_Updater_GetServers(t *testing.T) { presolver := common.NewMockParallelResolver(ctrl) if testCase.expectResolve { - presolver.EXPECT().Resolve(ctx, testCase.hostsToResolve, testCase.minServers). + presolver.EXPECT().Resolve(ctx, testCase.resolverSettings). Return(testCase.hostToIPs, testCase.resolveWarnings, testCase.resolveErr) } diff --git a/internal/provider/ipvanish/updater/updater.go b/internal/provider/ipvanish/updater/updater.go index 3d64a73e..6b9be217 100644 --- a/internal/provider/ipvanish/updater/updater.go +++ b/internal/provider/ipvanish/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { presolver common.ParallelResolver } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, warner: warner, - presolver: newParallelResolver(), + presolver: parallelResolver, } } diff --git a/internal/provider/ivpn/connection_test.go b/internal/provider/ivpn/connection_test.go index a9d2ad31..e7ea9d37 100644 --- a/internal/provider/ivpn/connection_test.go +++ b/internal/provider/ivpn/connection_test.go @@ -99,7 +99,8 @@ func Test_Provider_GetConnection(t *testing.T) { client := (*http.Client)(nil) warner := (common.Warner)(nil) - provider := New(storage, randSource, client, warner) + parallelResolver := (common.ParallelResolver)(nil) + provider := New(storage, randSource, client, warner, parallelResolver) connection, err := provider.GetConnection(testCase.selection) diff --git a/internal/provider/ivpn/provider.go b/internal/provider/ivpn/provider.go index 5e59195a..1b19ff03 100644 --- a/internal/provider/ivpn/provider.go +++ b/internal/provider/ivpn/provider.go @@ -18,12 +18,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - client *http.Client, updaterWarner common.Warner) *Provider { + client *http.Client, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn), - Fetcher: updater.New(client, updaterWarner), + Fetcher: updater.New(client, updaterWarner, parallelResolver), } } diff --git a/internal/provider/ivpn/updater/resolve.go b/internal/provider/ivpn/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/ivpn/updater/resolve.go +++ b/internal/provider/ivpn/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/ivpn/updater/servers.go b/internal/provider/ivpn/updater/servers.go index 5e523c3e..9d5680f3 100644 --- a/internal/provider/ivpn/updater/servers.go +++ b/internal/provider/ivpn/updater/servers.go @@ -38,7 +38,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( common.ErrNotEnoughServers, len(hosts), minServers) } - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -46,6 +47,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } + if len(hostToIPs) < minServers { + return nil, fmt.Errorf("%w: %d and expected at least %d", + common.ErrNotEnoughServers, len(servers), minServers) + } + servers = make([]models.Server, 0, len(hosts)) for _, serverData := range data.Servers { vpnType := vpn.OpenVPN diff --git a/internal/provider/ivpn/updater/servers_test.go b/internal/provider/ivpn/updater/servers_test.go index 7ba69701..446cf36d 100644 --- a/internal/provider/ivpn/updater/servers_test.go +++ b/internal/provider/ivpn/updater/servers_test.go @@ -8,11 +8,13 @@ import ( "net/http" "strings" "testing" + "time" "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/common" + "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,7 +35,7 @@ func Test_Updater_GetServers(t *testing.T) { // Resolution expectResolve bool - hostsToResolve []string + resolveSettings resolver.ParallelSettings hostToIPs map[string][]net.IP resolveWarnings []string resolveErr error @@ -56,9 +58,19 @@ func Test_Updater_GetServers(t *testing.T) { responseBody: `{"servers":[ {"hostnames":{"openvpn":"hosta"}} ]}`, - responseStatus: http.StatusOK, - expectResolve: true, - hostsToResolve: []string{"hosta"}, + responseStatus: http.StatusOK, + expectResolve: true, + resolveSettings: resolver.ParallelSettings{ + Hosts: []string{"hosta"}, + MaxFailRatio: 0.1, + Repeat: resolver.RepeatSettings{ + MaxDuration: 20 * time.Second, + BetweenDuration: time.Second, + MaxNoNew: 2, + MaxFails: 2, + SortIPs: true, + }, + }, resolveWarnings: []string{"resolve warning"}, resolveErr: errors.New("dummy"), err: errors.New("dummy"), @@ -86,7 +98,17 @@ func Test_Updater_GetServers(t *testing.T) { ]}`, responseStatus: http.StatusOK, expectResolve: true, - hostsToResolve: []string{"hosta", "hostb", "hostc"}, + resolveSettings: resolver.ParallelSettings{ + Hosts: []string{"hosta", "hostb", "hostc"}, + MaxFailRatio: 0.1, + Repeat: resolver.RepeatSettings{ + MaxDuration: 20 * time.Second, + BetweenDuration: time.Second, + MaxNoNew: 2, + MaxFails: 2, + SortIPs: true, + }, + }, hostToIPs: map[string][]net.IP{ "hosta": {{1, 1, 1, 1}, {2, 2, 2, 2}}, "hostb": {{3, 3, 3, 3}, {4, 4, 4, 4}}, @@ -130,7 +152,7 @@ func Test_Updater_GetServers(t *testing.T) { presolver := common.NewMockParallelResolver(ctrl) if testCase.expectResolve { - presolver.EXPECT().Resolve(ctx, testCase.hostsToResolve, testCase.minServers). + presolver.EXPECT().Resolve(ctx, testCase.resolveSettings). Return(testCase.hostToIPs, testCase.resolveWarnings, testCase.resolveErr) } diff --git a/internal/provider/ivpn/updater/updater.go b/internal/provider/ivpn/updater/updater.go index fbecad18..f75e7b12 100644 --- a/internal/provider/ivpn/updater/updater.go +++ b/internal/provider/ivpn/updater/updater.go @@ -12,10 +12,11 @@ type Updater struct { warner common.Warner } -func New(client *http.Client, warner common.Warner) *Updater { +func New(client *http.Client, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ client: client, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/privado/provider.go b/internal/provider/privado/provider.go index e17f8d4e..3c8c4ef6 100644 --- a/internal/provider/privado/provider.go +++ b/internal/provider/privado/provider.go @@ -19,12 +19,13 @@ type Provider struct { func New(storage common.Storage, randSource rand.Source, client *http.Client, unzipper common.Unzipper, - updaterWarner common.Warner) *Provider { + updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Privado), - Fetcher: updater.New(client, unzipper, updaterWarner), + Fetcher: updater.New(client, unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/privado/updater/resolve.go b/internal/provider/privado/updater/resolve.go index 1c921cfc..e5d341e7 100644 --- a/internal/provider/privado/updater/resolve.go +++ b/internal/provider/privado/updater/resolve.go @@ -6,14 +6,15 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 30 * time.Second maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -22,5 +23,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/privado/updater/servers.go b/internal/provider/privado/updater/servers.go index 6e1435bd..689b5776 100644 --- a/internal/provider/privado/updater/servers.go +++ b/internal/provider/privado/updater/servers.go @@ -50,7 +50,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -58,6 +59,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } + if len(hostToIPs) < minServers { + return nil, fmt.Errorf("%w: %d and expected at least %d", + common.ErrNotEnoughServers, len(servers), minServers) + } + hts.adaptWithIPs(hostToIPs) servers = hts.toServersSlice() diff --git a/internal/provider/privado/updater/updater.go b/internal/provider/privado/updater/updater.go index 0a8c5c02..971f0037 100644 --- a/internal/provider/privado/updater/updater.go +++ b/internal/provider/privado/updater/updater.go @@ -14,11 +14,11 @@ type Updater struct { } func New(client *http.Client, unzipper common.Unzipper, - warner common.Warner) *Updater { + warner common.Warner, parallelResolver common.ParallelResolver) *Updater { return &Updater{ client: client, unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/privatevpn/provider.go b/internal/provider/privatevpn/provider.go index 67e0d2a9..10b8c82e 100644 --- a/internal/provider/privatevpn/provider.go +++ b/internal/provider/privatevpn/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/privatevpn/updater/resolve.go b/internal/provider/privatevpn/updater/resolve.go index c0112af5..e01ff836 100644 --- a/internal/provider/privatevpn/updater/resolve.go +++ b/internal/provider/privatevpn/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 6 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/privatevpn/updater/servers.go b/internal/provider/privatevpn/updater/servers.go index 159aa21c..6de29dde 100644 --- a/internal/provider/privatevpn/updater/servers.go +++ b/internal/provider/privatevpn/updater/servers.go @@ -82,7 +82,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -90,6 +91,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } + if len(noHostnameServers)+len(hostToIPs) < minServers { + return nil, fmt.Errorf("%w: %d and expected at least %d", + common.ErrNotEnoughServers, len(servers), minServers) + } + hts.adaptWithIPs(hostToIPs) servers = hts.toServersSlice() diff --git a/internal/provider/privatevpn/updater/updater.go b/internal/provider/privatevpn/updater/updater.go index 452ef1ce..de83c263 100644 --- a/internal/provider/privatevpn/updater/updater.go +++ b/internal/provider/privatevpn/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { warner common.Warner } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/protonvpn/provider.go b/internal/provider/protonvpn/provider.go index f5b9e706..cfd71e4c 100644 --- a/internal/provider/protonvpn/provider.go +++ b/internal/provider/protonvpn/provider.go @@ -18,7 +18,8 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - client *http.Client, updaterWarner common.Warner) *Provider { + client *http.Client, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, diff --git a/internal/provider/providers.go b/internal/provider/providers.go index 59480934..a83b2be2 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -44,33 +44,36 @@ type Storage interface { } func NewProviders(storage Storage, timeNow func() time.Time, - updaterWarner common.Warner, client *http.Client, unzipper common.Unzipper) *Providers { + updaterWarner common.Warner, client *http.Client, unzipper common.Unzipper, + parallelResolver common.ParallelResolver) *Providers { randSource := rand.NewSource(timeNow().UnixNano()) - targetLength := len(providers.AllWithCustom()) - providerNameToProvider := make(map[string]Provider, targetLength) - providerNameToProvider[providers.Custom] = custom.New() - providerNameToProvider[providers.Cyberghost] = cyberghost.New(storage, randSource) - providerNameToProvider[providers.Expressvpn] = expressvpn.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.Fastestvpn] = fastestvpn.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.HideMyAss] = hidemyass.New(storage, randSource, client, updaterWarner) - providerNameToProvider[providers.Ipvanish] = ipvanish.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.Ivpn] = ivpn.New(storage, randSource, client, updaterWarner) - providerNameToProvider[providers.Mullvad] = mullvad.New(storage, randSource, client) - providerNameToProvider[providers.Nordvpn] = nordvpn.New(storage, randSource, client, updaterWarner) - providerNameToProvider[providers.Perfectprivacy] = perfectprivacy.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.Privado] = privado.New(storage, randSource, client, unzipper, updaterWarner) - providerNameToProvider[providers.PrivateInternetAccess] = privateinternetaccess.New(storage, randSource, timeNow, client) //nolint:lll - providerNameToProvider[providers.Privatevpn] = privatevpn.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.Protonvpn] = protonvpn.New(storage, randSource, client, updaterWarner) - providerNameToProvider[providers.Purevpn] = purevpn.New(storage, randSource, client, unzipper, updaterWarner) - providerNameToProvider[providers.Surfshark] = surfshark.New(storage, randSource, client, unzipper, updaterWarner) - providerNameToProvider[providers.Torguard] = torguard.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.VPNUnlimited] = vpnunlimited.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.Vyprvpn] = vyprvpn.New(storage, randSource, unzipper, updaterWarner) - providerNameToProvider[providers.Wevpn] = wevpn.New(storage, randSource, updaterWarner) - providerNameToProvider[providers.Windscribe] = windscribe.New(storage, randSource, client, updaterWarner) + //nolint:lll + providerNameToProvider := map[string]Provider{ + providers.Custom: custom.New(), + providers.Cyberghost: cyberghost.New(storage, randSource, parallelResolver), + providers.Expressvpn: expressvpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.Fastestvpn: fastestvpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.HideMyAss: hidemyass.New(storage, randSource, client, updaterWarner, parallelResolver), + providers.Ipvanish: ipvanish.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.Ivpn: ivpn.New(storage, randSource, client, updaterWarner, parallelResolver), + providers.Mullvad: mullvad.New(storage, randSource, client), + providers.Nordvpn: nordvpn.New(storage, randSource, client, updaterWarner), + providers.Perfectprivacy: perfectprivacy.New(storage, randSource, unzipper, updaterWarner), + providers.Privado: privado.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver), + providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client), + providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, parallelResolver), + providers.Purevpn: purevpn.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver), + providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver), + providers.Torguard: torguard.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.VPNUnlimited: vpnunlimited.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.Vyprvpn: vyprvpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), + providers.Wevpn: wevpn.New(storage, randSource, updaterWarner, parallelResolver), + providers.Windscribe: windscribe.New(storage, randSource, client, updaterWarner), + } + targetLength := len(providers.AllWithCustom()) if len(providerNameToProvider) != targetLength { // Programming sanity check panic(fmt.Sprintf("invalid number of providers, expected %d but got %d", diff --git a/internal/provider/purevpn/provider.go b/internal/provider/purevpn/provider.go index ac78689b..dc6e9adc 100644 --- a/internal/provider/purevpn/provider.go +++ b/internal/provider/purevpn/provider.go @@ -18,12 +18,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn), - Fetcher: updater.New(client, unzipper, updaterWarner), + Fetcher: updater.New(client, unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/purevpn/updater/resolve.go b/internal/provider/purevpn/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/purevpn/updater/resolve.go +++ b/internal/provider/purevpn/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/purevpn/updater/servers.go b/internal/provider/purevpn/updater/servers.go index cf862a2d..21098297 100644 --- a/internal/provider/purevpn/updater/servers.go +++ b/internal/provider/purevpn/updater/servers.go @@ -60,7 +60,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -68,15 +69,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } - hts.adaptWithIPs(hostToIPs) - - servers = hts.toServersSlice() - - if len(servers) < minServers { + if len(hostToIPs) < minServers { return nil, fmt.Errorf("%w: %d and expected at least %d", common.ErrNotEnoughServers, len(servers), minServers) } + hts.adaptWithIPs(hostToIPs) + + servers = hts.toServersSlice() + // Get public IP address information ipsToGetInfo := make([]net.IP, len(servers)) for i := range servers { diff --git a/internal/provider/purevpn/updater/updater.go b/internal/provider/purevpn/updater/updater.go index 0a8c5c02..971f0037 100644 --- a/internal/provider/purevpn/updater/updater.go +++ b/internal/provider/purevpn/updater/updater.go @@ -14,11 +14,11 @@ type Updater struct { } func New(client *http.Client, unzipper common.Unzipper, - warner common.Warner) *Updater { + warner common.Warner, parallelResolver common.ParallelResolver) *Updater { return &Updater{ client: client, unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/surfshark/provider.go b/internal/provider/surfshark/provider.go index 065b8f4d..58872deb 100644 --- a/internal/provider/surfshark/provider.go +++ b/internal/provider/surfshark/provider.go @@ -18,12 +18,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark), - Fetcher: updater.New(client, unzipper, updaterWarner), + Fetcher: updater.New(client, unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/surfshark/updater/resolve.go b/internal/provider/surfshark/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/surfshark/updater/resolve.go +++ b/internal/provider/surfshark/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/surfshark/updater/servers.go b/internal/provider/surfshark/updater/servers.go index 602bcd86..c2032eaf 100644 --- a/internal/provider/surfshark/updater/servers.go +++ b/internal/provider/surfshark/updater/servers.go @@ -31,7 +31,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( getRemainingServers(hts) hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -39,15 +40,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } - hts.adaptWithIPs(hostToIPs) - - servers = hts.toServersSlice() - - if len(servers) < minServers { + if len(hostToIPs) < minServers { return nil, fmt.Errorf("%w: %d and expected at least %d", common.ErrNotEnoughServers, len(servers), minServers) } + hts.adaptWithIPs(hostToIPs) + + servers = hts.toServersSlice() + sort.Sort(models.SortableServers(servers)) return servers, nil diff --git a/internal/provider/surfshark/updater/updater.go b/internal/provider/surfshark/updater/updater.go index 0a8c5c02..971f0037 100644 --- a/internal/provider/surfshark/updater/updater.go +++ b/internal/provider/surfshark/updater/updater.go @@ -14,11 +14,11 @@ type Updater struct { } func New(client *http.Client, unzipper common.Unzipper, - warner common.Warner) *Updater { + warner common.Warner, parallelResolver common.ParallelResolver) *Updater { return &Updater{ client: client, unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/torguard/provider.go b/internal/provider/torguard/provider.go index 6f3a8e5f..e5a85ffb 100644 --- a/internal/provider/torguard/provider.go +++ b/internal/provider/torguard/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/torguard/updater/resolve.go b/internal/provider/torguard/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/torguard/updater/resolve.go +++ b/internal/provider/torguard/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/torguard/updater/servers.go b/internal/provider/torguard/updater/servers.go index 5c8e7851..369d9396 100644 --- a/internal/provider/torguard/updater/servers.go +++ b/internal/provider/torguard/updater/servers.go @@ -50,12 +50,18 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) u.warnWarnings(warnings) if err != nil { return nil, err } + if len(hostToIPs) < minServers { + return nil, fmt.Errorf("%w: %d and expected at least %d", + common.ErrNotEnoughServers, len(servers), minServers) + } + hts.adaptWithIPs(hostToIPs) servers = hts.toServersSlice() diff --git a/internal/provider/torguard/updater/updater.go b/internal/provider/torguard/updater/updater.go index 452ef1ce..de83c263 100644 --- a/internal/provider/torguard/updater/updater.go +++ b/internal/provider/torguard/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { warner common.Warner } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/vpnunlimited/provider.go b/internal/provider/vpnunlimited/provider.go index 51c0c999..33f3412c 100644 --- a/internal/provider/vpnunlimited/provider.go +++ b/internal/provider/vpnunlimited/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/vpnunlimited/updater/resolve.go b/internal/provider/vpnunlimited/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/vpnunlimited/updater/resolve.go +++ b/internal/provider/vpnunlimited/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/vpnunlimited/updater/servers.go b/internal/provider/vpnunlimited/updater/servers.go index c485556b..dd58f7f5 100644 --- a/internal/provider/vpnunlimited/updater/servers.go +++ b/internal/provider/vpnunlimited/updater/servers.go @@ -20,7 +20,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -28,15 +29,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } - hts.adaptWithIPs(hostToIPs) - - servers = hts.toServersSlice() - - if len(servers) < minServers { + if len(hostToIPs) < minServers { return nil, fmt.Errorf("%w: %d and expected at least %d", common.ErrNotEnoughServers, len(servers), minServers) } + hts.adaptWithIPs(hostToIPs) + + servers = hts.toServersSlice() + sort.Sort(models.SortableServers(servers)) return servers, nil diff --git a/internal/provider/vpnunlimited/updater/updater.go b/internal/provider/vpnunlimited/updater/updater.go index 452ef1ce..de83c263 100644 --- a/internal/provider/vpnunlimited/updater/updater.go +++ b/internal/provider/vpnunlimited/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { warner common.Warner } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/vyprvpn/provider.go b/internal/provider/vyprvpn/provider.go index d055489d..d5eddfcb 100644 --- a/internal/provider/vyprvpn/provider.go +++ b/internal/provider/vyprvpn/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - unzipper common.Unzipper, updaterWarner common.Warner) *Provider { + unzipper common.Unzipper, updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn), - Fetcher: updater.New(unzipper, updaterWarner), + Fetcher: updater.New(unzipper, updaterWarner, parallelResolver), } } diff --git a/internal/provider/vyprvpn/updater/resolve.go b/internal/provider/vyprvpn/updater/resolve.go index 7a04574b..fde66d47 100644 --- a/internal/provider/vyprvpn/updater/resolve.go +++ b/internal/provider/vyprvpn/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 5 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/vyprvpn/updater/servers.go b/internal/provider/vyprvpn/updater/servers.go index 54dcc304..c21f4c5f 100644 --- a/internal/provider/vyprvpn/updater/servers.go +++ b/internal/provider/vyprvpn/updater/servers.go @@ -64,7 +64,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( } hosts := hts.toHostsSlice() - hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) + resolveSettings := parallelResolverSettings(hosts) + hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings) for _, warning := range warnings { u.warner.Warn(warning) } @@ -72,15 +73,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( return nil, err } - hts.adaptWithIPs(hostToIPs) - - servers = hts.toServersSlice() - - if len(servers) < minServers { + if len(hostToIPs) < minServers { return nil, fmt.Errorf("%w: %d and expected at least %d", common.ErrNotEnoughServers, len(servers), minServers) } + hts.adaptWithIPs(hostToIPs) + + servers = hts.toServersSlice() + sort.Sort(models.SortableServers(servers)) return servers, nil diff --git a/internal/provider/vyprvpn/updater/updater.go b/internal/provider/vyprvpn/updater/updater.go index 452ef1ce..de83c263 100644 --- a/internal/provider/vyprvpn/updater/updater.go +++ b/internal/provider/vyprvpn/updater/updater.go @@ -10,10 +10,11 @@ type Updater struct { warner common.Warner } -func New(unzipper common.Unzipper, warner common.Warner) *Updater { +func New(unzipper common.Unzipper, warner common.Warner, + parallelResolver common.ParallelResolver) *Updater { return &Updater{ unzipper: unzipper, - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/provider/wevpn/connection_test.go b/internal/provider/wevpn/connection_test.go index 7033f66c..979b53ca 100644 --- a/internal/provider/wevpn/connection_test.go +++ b/internal/provider/wevpn/connection_test.go @@ -93,7 +93,8 @@ func Test_Provider_GetConnection(t *testing.T) { randSource := rand.NewSource(0) warner := (common.Warner)(nil) - provider := New(storage, randSource, warner) + parallelResolver := (common.ParallelResolver)(nil) + provider := New(storage, randSource, warner, parallelResolver) if testCase.panicMessage != "" { assert.PanicsWithValue(t, testCase.panicMessage, func() { diff --git a/internal/provider/wevpn/provider.go b/internal/provider/wevpn/provider.go index 43db02fc..6c89da2f 100644 --- a/internal/provider/wevpn/provider.go +++ b/internal/provider/wevpn/provider.go @@ -17,12 +17,13 @@ type Provider struct { } func New(storage common.Storage, randSource rand.Source, - updaterWarner common.Warner) *Provider { + updaterWarner common.Warner, + parallelResolver common.ParallelResolver) *Provider { return &Provider{ storage: storage, randSource: randSource, NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn), - Fetcher: updater.New(updaterWarner), + Fetcher: updater.New(updaterWarner, parallelResolver), } } diff --git a/internal/provider/wevpn/updater/resolve.go b/internal/provider/wevpn/updater/resolve.go index c01e93b3..d0b3fa0c 100644 --- a/internal/provider/wevpn/updater/resolve.go +++ b/internal/provider/wevpn/updater/resolve.go @@ -6,7 +6,7 @@ import ( "github.com/qdm12/gluetun/internal/updater/resolver" ) -func newParallelResolver() (parallelResolver *resolver.Parallel) { +func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) { const ( maxFailRatio = 0.1 maxDuration = 20 * time.Second @@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { maxNoNew = 2 maxFails = 2 ) - settings := resolver.ParallelSettings{ + return resolver.ParallelSettings{ + Hosts: hosts, MaxFailRatio: maxFailRatio, Repeat: resolver.RepeatSettings{ MaxDuration: maxDuration, @@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) { SortIPs: true, }, } - return resolver.NewParallelResolver(settings) } diff --git a/internal/provider/wevpn/updater/servers.go b/internal/provider/wevpn/updater/servers.go index 0647fc9c..b8d68fb2 100644 --- a/internal/provider/wevpn/updater/servers.go +++ b/internal/provider/wevpn/updater/servers.go @@ -31,7 +31,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) ( hostnameToCity[hostname] = city } - hostnameToIPs, warnings, err := u.presolver.Resolve(ctx, hostnames, minServers) + resolverSettings := parallelResolverSettings(hostnames) + hostnameToIPs, warnings, err := u.presolver.Resolve(ctx, resolverSettings) for _, warning := range warnings { u.warner.Warn(warning) } diff --git a/internal/provider/wevpn/updater/updater.go b/internal/provider/wevpn/updater/updater.go index 06a68c69..3a3140b6 100644 --- a/internal/provider/wevpn/updater/updater.go +++ b/internal/provider/wevpn/updater/updater.go @@ -7,9 +7,9 @@ type Updater struct { warner common.Warner } -func New(warner common.Warner) *Updater { +func New(warner common.Warner, parallelResolver common.ParallelResolver) *Updater { return &Updater{ - presolver: newParallelResolver(), + presolver: parallelResolver, warner: warner, } } diff --git a/internal/updater/resolver/parallel.go b/internal/updater/resolver/parallel.go index 27ff2800..6f3506e4 100644 --- a/internal/updater/resolver/parallel.go +++ b/internal/updater/resolver/parallel.go @@ -9,17 +9,17 @@ import ( type Parallel struct { repeatResolver *Repeat - settings ParallelSettings } -func NewParallelResolver(settings ParallelSettings) *Parallel { +func NewParallelResolver(resolverAddress string) *Parallel { return &Parallel{ - repeatResolver: NewRepeat(settings.Repeat), - settings: settings, + repeatResolver: NewRepeat(resolverAddress), } } type ParallelSettings struct { + // Hosts to resolve in parallel. + Hosts []string Repeat RepeatSettings FailEarly bool // Maximum ratio of the hosts failing DNS resolution @@ -39,7 +39,7 @@ var ( ErrMaxFailRatio = errors.New("maximum failure ratio reached") ) -func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int) ( +func (pr *Parallel) Resolve(ctx context.Context, settings ParallelSettings) ( hostToIPs map[string][]net.IP, warnings []string, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -49,17 +49,17 @@ func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int) errors := make(chan error) defer close(errors) - for _, host := range hosts { - go pr.resolveAsync(ctx, host, results, errors) + for _, host := range settings.Hosts { + go pr.resolveAsync(ctx, host, settings.Repeat, results, errors) } - hostToIPs = make(map[string][]net.IP, len(hosts)) - maxFails := int(pr.settings.MaxFailRatio * float64(len(hosts))) + hostToIPs = make(map[string][]net.IP, len(settings.Hosts)) + maxFails := int(settings.MaxFailRatio * float64(len(settings.Hosts))) - for range hosts { + for range settings.Hosts { select { case newErr := <-errors: - if pr.settings.FailEarly { + if settings.FailEarly { if err == nil { // only set the error to the first error encountered // and not the context canceled errors coming after. @@ -86,14 +86,8 @@ func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int) return nil, warnings, err } - if len(hostToIPs) < minToFind { - return nil, warnings, - fmt.Errorf("%w: found %d hosts but expected at least %d", - ErrMinFound, len(hostToIPs), minToFind) - } - - failureRatio := float64(len(warnings)) / float64(len(hosts)) - if failureRatio > pr.settings.MaxFailRatio { + failureRatio := float64(len(warnings)) / float64(len(settings.Hosts)) + if failureRatio > settings.MaxFailRatio { return hostToIPs, warnings, fmt.Errorf("%w: %.2f failure ratio reached", ErrMaxFailRatio, failureRatio) } @@ -102,8 +96,8 @@ func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int) } func (pr *Parallel) resolveAsync(ctx context.Context, host string, - results chan<- parallelResult, errors chan<- error) { - IPs, err := pr.repeatResolver.Resolve(ctx, host) + settings RepeatSettings, results chan<- parallelResult, errors chan<- error) { + IPs, err := pr.repeatResolver.Resolve(ctx, host, settings) if err != nil { errors <- err return diff --git a/internal/updater/resolver/repeat.go b/internal/updater/resolver/repeat.go index 1b470498..94822f8b 100644 --- a/internal/updater/resolver/repeat.go +++ b/internal/updater/resolver/repeat.go @@ -12,13 +12,11 @@ import ( type Repeat struct { resolver *net.Resolver - settings RepeatSettings } -func NewRepeat(settings RepeatSettings) *Repeat { +func NewRepeat(resolverAddress string) *Repeat { return &Repeat{ - resolver: newResolver(settings.Address), - settings: settings, + resolver: newResolver(resolverAddress), } } @@ -32,8 +30,9 @@ type RepeatSettings struct { SortIPs bool } -func (r *Repeat) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { - timedCtx, cancel := context.WithTimeout(ctx, r.settings.MaxDuration) +func (r *Repeat) Resolve(ctx context.Context, host string, settings RepeatSettings) ( + ips []net.IP, err error) { + timedCtx, cancel := context.WithTimeout(ctx, settings.MaxDuration) defer cancel() noNewCounter := 0 @@ -44,7 +43,7 @@ func (r *Repeat) Resolve(ctx context.Context, host string) (ips []net.IP, err er // TODO // - one resolving every 100ms for round robin DNS responses // - one every second for time based DNS cycling responses - noNewCounter, failCounter, err = r.resolveOnce(ctx, timedCtx, host, r.settings, uniqueIPs, noNewCounter, failCounter) + noNewCounter, failCounter, err = r.resolveOnce(ctx, timedCtx, host, settings, uniqueIPs, noNewCounter, failCounter) } if len(uniqueIPs) == 0 { @@ -53,7 +52,7 @@ func (r *Repeat) Resolve(ctx context.Context, host string) (ips []net.IP, err er ips = uniqueIPsToSlice(uniqueIPs) - if r.settings.SortIPs { + if settings.SortIPs { sort.Slice(ips, func(i, j int) bool { return bytes.Compare(ips[i], ips[j]) < 1 })