updater: refactoring and set DNS server correctly

- Fix CLI operation not setting DNS server
- Fix periodic operation not setting DNS server
- Set DNS address for resolution once at start for both CLI and periodic operation
- Inject resolver to each provider instead of creating it within
- Use resolver settings on every call to `.Resolve` method, instead of passing it to constructor
- Move out minServers check from resolver
This commit is contained in:
Quentin McGaw
2022-06-11 17:41:57 +00:00
parent 1bd355ab96
commit 447a7c9891
70 changed files with 366 additions and 229 deletions

View File

@@ -39,6 +39,7 @@ import (
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/tun" "github.com/qdm12/gluetun/internal/tun"
updater "github.com/qdm12/gluetun/internal/updater/loop" updater "github.com/qdm12/gluetun/internal/updater/loop"
"github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/updater/unzip"
"github.com/qdm12/gluetun/internal/vpn" "github.com/qdm12/gluetun/internal/vpn"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
@@ -379,7 +380,9 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
updaterLogger := logger.New(log.SetComponent("updater")) updaterLogger := logger.New(log.SetComponent("updater"))
unzipper := unzip.New(httpClient) unzipper := unzip.New(httpClient)
providers := provider.NewProviders(storage, time.Now, updaterLogger, httpClient, unzipper) parallelResolver := resolver.NewParallelResolver(allSettings.Updater.DNSAddress.String())
providers := provider.NewProviders(storage, time.Now,
updaterLogger, httpClient, unzipper, parallelResolver)
vpnLogger := logger.New(log.SetComponent("vpn")) vpnLogger := logger.New(log.SetComponent("vpn"))
vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts, vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.Firewall.VPNInputPorts,

View File

@@ -3,6 +3,7 @@ package cli
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@@ -11,6 +12,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater/resolver"
) )
type OpenvpnConfigLogger interface { type OpenvpnConfigLogger interface {
@@ -23,6 +25,11 @@ type Unzipper interface {
contents map[string][]byte, err error) contents map[string][]byte, err error)
} }
type ParallelResolver interface {
Resolve(ctx context.Context, settings resolver.ParallelSettings) (
hostToIPs map[string][]net.IP, warnings []string, err error)
}
func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) error { func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) error {
storage, err := storage.New(logger, constants.ServersData) storage, err := storage.New(logger, constants.ServersData)
if err != nil { if err != nil {
@@ -42,8 +49,9 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source sources.Source) e
unzipper := (Unzipper)(nil) unzipper := (Unzipper)(nil)
client := (*http.Client)(nil) client := (*http.Client)(nil)
warner := (Warner)(nil) warner := (Warner)(nil)
parallelResolver := (ParallelResolver)(nil)
providers := provider.NewProviders(storage, time.Now, warner, client, unzipper) providers := provider.NewProviders(storage, time.Now, warner, client, unzipper, parallelResolver)
providerConf := providers.Get(*allSettings.VPN.Provider.Name) providerConf := providers.Get(*allSettings.VPN.Provider.Name)
connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection) connection, err := providerConf.GetConnection(allSettings.VPN.Provider.ServerSelection)
if err != nil { if err != nil {

View File

@@ -16,6 +16,7 @@ import (
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/updater/unzip"
) )
@@ -71,17 +72,17 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
return fmt.Errorf("options validation failed: %w", err) return fmt.Errorf("options validation failed: %w", err)
} }
const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
storage, err := storage.New(logger, constants.ServersData) storage, err := storage.New(logger, constants.ServersData)
if err != nil { if err != nil {
return fmt.Errorf("cannot create servers storage: %w", err) return fmt.Errorf("cannot create servers storage: %w", err)
} }
const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
unzipper := unzip.New(httpClient) unzipper := unzip.New(httpClient)
parallelResolver := resolver.NewParallelResolver(options.DNSAddress.String())
providers := provider.NewProviders(storage, time.Now, logger, httpClient, unzipper) providers := provider.NewProviders(storage, time.Now, logger, httpClient, unzipper, parallelResolver)
updater := updater.New(httpClient, storage, providers, logger) updater := updater.New(httpClient, storage, providers, logger)
err = updater.UpdateServers(ctx, options.Providers) err = updater.UpdateServers(ctx, options.Providers)

View File

@@ -12,6 +12,7 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
settings "github.com/qdm12/gluetun/internal/configuration/settings" settings "github.com/qdm12/gluetun/internal/configuration/settings"
models "github.com/qdm12/gluetun/internal/models" models "github.com/qdm12/gluetun/internal/models"
resolver "github.com/qdm12/gluetun/internal/updater/resolver"
) )
// MockParallelResolver is a mock of ParallelResolver interface. // MockParallelResolver is a mock of ParallelResolver interface.
@@ -38,9 +39,9 @@ func (m *MockParallelResolver) EXPECT() *MockParallelResolverMockRecorder {
} }
// Resolve mocks base method. // Resolve mocks base method.
func (m *MockParallelResolver) Resolve(arg0 context.Context, arg1 []string, arg2 int) (map[string][]net.IP, []string, error) { func (m *MockParallelResolver) Resolve(arg0 context.Context, arg1 resolver.ParallelSettings) (map[string][]net.IP, []string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Resolve", arg0, arg1, arg2) ret := m.ctrl.Call(m, "Resolve", arg0, arg1)
ret0, _ := ret[0].(map[string][]net.IP) ret0, _ := ret[0].(map[string][]net.IP)
ret1, _ := ret[1].([]string) ret1, _ := ret[1].([]string)
ret2, _ := ret[2].(error) ret2, _ := ret[2].(error)
@@ -48,9 +49,9 @@ func (m *MockParallelResolver) Resolve(arg0 context.Context, arg1 []string, arg2
} }
// Resolve indicates an expected call of Resolve. // Resolve indicates an expected call of Resolve.
func (mr *MockParallelResolverMockRecorder) Resolve(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockParallelResolverMockRecorder) Resolve(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockParallelResolver)(nil).Resolve), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockParallelResolver)(nil).Resolve), arg0, arg1)
} }
// MockStorage is a mock of Storage interface. // MockStorage is a mock of Storage interface.

View File

@@ -6,6 +6,7 @@ import (
"net" "net"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/updater/resolver"
) )
var ErrNotEnoughServers = errors.New("not enough servers found") var ErrNotEnoughServers = errors.New("not enough servers found")
@@ -15,7 +16,7 @@ type Fetcher interface {
} }
type ParallelResolver interface { type ParallelResolver interface {
Resolve(ctx context.Context, hosts []string, minToFind int) ( Resolve(ctx context.Context, settings resolver.ParallelSettings) (
hostToIPs map[string][]net.IP, warnings []string, err error) hostToIPs map[string][]net.IP, warnings []string, err error)
} }

View File

@@ -16,12 +16,13 @@ type Provider struct {
common.Fetcher common.Fetcher
} }
func New(storage common.Storage, randSource rand.Source) *Provider { func New(storage common.Storage, randSource rand.Source,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost), NoPortForwarder: utils.NewNoPortForwarding(providers.Cyberghost),
Fetcher: updater.New(), Fetcher: updater.New(parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 1 maxFailRatio = 1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 4 maxNoNew = 4
maxFails = 10 maxFails = 10
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -4,9 +4,11 @@ package updater
import ( import (
"context" "context"
"fmt"
"sort" "sort"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
) )
func (u *Updater) FetchServers(ctx context.Context, minServers int) ( func (u *Updater) FetchServers(ctx context.Context, minServers int) (
@@ -14,11 +16,17 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
possibleServers := getPossibleServers() possibleServers := getPossibleServers()
possibleHosts := possibleServers.hostsSlice() possibleHosts := possibleServers.hostsSlice()
hostToIPs, _, err := u.presolver.Resolve(ctx, possibleHosts, minServers) resolveSettings := parallelResolverSettings(possibleHosts)
hostToIPs, _, err := u.presolver.Resolve(ctx, resolveSettings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(hostToIPs) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
possibleServers.adaptWithIPs(hostToIPs) possibleServers.adaptWithIPs(hostToIPs)
servers = possibleServers.toSlice() servers = possibleServers.toSlice()

View File

@@ -8,8 +8,8 @@ type Updater struct {
presolver common.ParallelResolver presolver common.ParallelResolver
} }
func New() *Updater { func New(parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
presolver: newParallelResolver(), presolver: parallelResolver,
} }
} }

View File

@@ -89,7 +89,8 @@ func Test_Provider_GetConnection(t *testing.T) {
unzipper := (common.Unzipper)(nil) unzipper := (common.Unzipper)(nil)
warner := (common.Warner)(nil) warner := (common.Warner)(nil)
provider := New(storage, randSource, unzipper, warner) parallelResolver := (common.ParallelResolver)(nil)
provider := New(storage, randSource, unzipper, warner, parallelResolver)
if testCase.panicMessage != "" { if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() { assert.PanicsWithValue(t, testCase.panicMessage, func() {

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Expressvpn),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,13 +6,14 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() *resolver.Parallel { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxNoNew = 1 maxNoNew = 1
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: time.Second, MaxDuration: time.Second,
@@ -21,5 +22,4 @@ func newParallelResolver() *resolver.Parallel {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -21,7 +21,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
hosts[i] = servers[i].Hostname hosts[i] = servers[i].Hostname
} }
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }

View File

@@ -10,10 +10,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Fastestvpn),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,13 +6,14 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() *resolver.Parallel { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxNoNew = 1 maxNoNew = 1
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: time.Second, MaxDuration: time.Second,
@@ -21,5 +22,4 @@ func newParallelResolver() *resolver.Parallel {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -56,7 +56,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -64,15 +65,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
hts.adaptWithIPs(hostToIPs) if len(hostToIPs) < minServers {
servers = hts.toServersSlice()
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d", return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers) common.ErrNotEnoughServers, len(servers), minServers)
} }
hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice()
sort.Sort(models.SortableServers(servers)) sort.Sort(models.SortableServers(servers))
return servers, nil return servers, nil

View File

@@ -10,10 +10,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -18,12 +18,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, updaterWarner common.Warner) *Provider { client *http.Client, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss), NoPortForwarder: utils.NewNoPortForwarding(providers.HideMyAss),
Fetcher: updater.New(client, updaterWarner), Fetcher: updater.New(client, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() *resolver.Parallel { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 15 * time.Second maxDuration = 15 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() *resolver.Parallel {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() *resolver.Parallel {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -26,7 +26,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
common.ErrNotEnoughServers, len(hosts), minServers) common.ErrNotEnoughServers, len(hosts), minServers)
} }
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -34,6 +35,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
if len(hostToIPs) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
servers = make([]models.Server, 0, len(hostToIPs)) servers = make([]models.Server, 0, len(hostToIPs))
for host, IPs := range hostToIPs { for host, IPs := range hostToIPs {
tcpURL, tcp := tcpHostToURL[host] tcpURL, tcp := tcpHostToURL[host]

View File

@@ -12,10 +12,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(client *http.Client, warner common.Warner) *Updater { func New(client *http.Client, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish), NoPortForwarder: utils.NewNoPortForwarding(providers.Ipvanish),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -67,7 +67,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -75,15 +76,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
hts.adaptWithIPs(hostToIPs) if len(hostToIPs) < minServers {
servers = hts.toServersSlice()
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d", return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers) common.ErrNotEnoughServers, len(servers), minServers)
} }
hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice()
sort.Sort(models.SortableServers(servers)) sort.Sort(models.SortableServers(servers))
return servers, nil return servers, nil

View File

@@ -5,11 +5,13 @@ import (
"errors" "errors"
"net" "net"
"testing" "testing"
"time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -28,11 +30,11 @@ func Test_Updater_GetServers(t *testing.T) {
unzipErr error unzipErr error
// Resolution // Resolution
expectResolve bool expectResolve bool
hostsToResolve []string resolverSettings resolver.ParallelSettings
hostToIPs map[string][]net.IP hostToIPs map[string][]net.IP
resolveWarnings []string resolveWarnings []string
resolveErr error resolveErr error
// Output // Output
servers []models.Server servers []models.Server
@@ -85,9 +87,19 @@ func Test_Updater_GetServers(t *testing.T) {
unzipContents: map[string][]byte{ unzipContents: map[string][]byte{
"ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta\nremote hostb"), "ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta\nremote hostb"),
}, },
expectResolve: true, expectResolve: true,
hostsToResolve: []string{"hosta"}, resolverSettings: resolver.ParallelSettings{
err: errors.New("not enough servers found: 0 and expected at least 1"), Hosts: []string{"hosta"},
MaxFailRatio: 0.1,
Repeat: resolver.RepeatSettings{
MaxDuration: 20 * time.Second,
BetweenDuration: time.Second,
MaxNoNew: 2,
MaxFails: 2,
SortIPs: true,
},
},
err: errors.New("not enough servers found: 0 and expected at least 1"),
}, },
"resolve error": { "resolve error": {
warnerBuilder: func(ctrl *gomock.Controller) common.Warner { warnerBuilder: func(ctrl *gomock.Controller) common.Warner {
@@ -98,8 +110,18 @@ func Test_Updater_GetServers(t *testing.T) {
unzipContents: map[string][]byte{ unzipContents: map[string][]byte{
"ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta"), "ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta"),
}, },
expectResolve: true, expectResolve: true,
hostsToResolve: []string{"hosta"}, resolverSettings: resolver.ParallelSettings{
Hosts: []string{"hosta"},
MaxFailRatio: 0.1,
Repeat: resolver.RepeatSettings{
MaxDuration: 20 * time.Second,
BetweenDuration: time.Second,
MaxNoNew: 2,
MaxFails: 2,
SortIPs: true,
},
},
resolveWarnings: []string{"resolve warning"}, resolveWarnings: []string{"resolve warning"},
resolveErr: errors.New("dummy"), resolveErr: errors.New("dummy"),
err: errors.New("dummy"), err: errors.New("dummy"),
@@ -127,8 +149,18 @@ func Test_Updater_GetServers(t *testing.T) {
"ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta"), "ipvanish-CA-City-A-hosta.ovpn": []byte("remote hosta"),
"ipvanish-LU-City-B-hostb.ovpn": []byte("remote hostb"), "ipvanish-LU-City-B-hostb.ovpn": []byte("remote hostb"),
}, },
expectResolve: true, expectResolve: true,
hostsToResolve: []string{"hosta", "hostb"}, resolverSettings: resolver.ParallelSettings{
Hosts: []string{"hosta", "hostb"},
MaxFailRatio: 0.1,
Repeat: resolver.RepeatSettings{
MaxDuration: 20 * time.Second,
BetweenDuration: time.Second,
MaxNoNew: 2,
MaxFails: 2,
SortIPs: true,
},
},
hostToIPs: map[string][]net.IP{ hostToIPs: map[string][]net.IP{
"hosta": {{1, 1, 1, 1}, {2, 2, 2, 2}}, "hosta": {{1, 1, 1, 1}, {2, 2, 2, 2}},
"hostb": {{3, 3, 3, 3}, {4, 4, 4, 4}}, "hostb": {{3, 3, 3, 3}, {4, 4, 4, 4}},
@@ -169,7 +201,7 @@ func Test_Updater_GetServers(t *testing.T) {
presolver := common.NewMockParallelResolver(ctrl) presolver := common.NewMockParallelResolver(ctrl)
if testCase.expectResolve { if testCase.expectResolve {
presolver.EXPECT().Resolve(ctx, testCase.hostsToResolve, testCase.minServers). presolver.EXPECT().Resolve(ctx, testCase.resolverSettings).
Return(testCase.hostToIPs, testCase.resolveWarnings, testCase.resolveErr) Return(testCase.hostToIPs, testCase.resolveWarnings, testCase.resolveErr)
} }

View File

@@ -10,10 +10,11 @@ type Updater struct {
presolver common.ParallelResolver presolver common.ParallelResolver
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
warner: warner, warner: warner,
presolver: newParallelResolver(), presolver: parallelResolver,
} }
} }

View File

@@ -99,7 +99,8 @@ func Test_Provider_GetConnection(t *testing.T) {
client := (*http.Client)(nil) client := (*http.Client)(nil)
warner := (common.Warner)(nil) warner := (common.Warner)(nil)
provider := New(storage, randSource, client, warner) parallelResolver := (common.ParallelResolver)(nil)
provider := New(storage, randSource, client, warner, parallelResolver)
connection, err := provider.GetConnection(testCase.selection) connection, err := provider.GetConnection(testCase.selection)

View File

@@ -18,12 +18,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, updaterWarner common.Warner) *Provider { client *http.Client, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Ivpn),
Fetcher: updater.New(client, updaterWarner), Fetcher: updater.New(client, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -38,7 +38,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
common.ErrNotEnoughServers, len(hosts), minServers) common.ErrNotEnoughServers, len(hosts), minServers)
} }
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -46,6 +47,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
if len(hostToIPs) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
servers = make([]models.Server, 0, len(hosts)) servers = make([]models.Server, 0, len(hosts))
for _, serverData := range data.Servers { for _, serverData := range data.Servers {
vpnType := vpn.OpenVPN vpnType := vpn.OpenVPN

View File

@@ -8,11 +8,13 @@ import (
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/updater/resolver"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -33,7 +35,7 @@ func Test_Updater_GetServers(t *testing.T) {
// Resolution // Resolution
expectResolve bool expectResolve bool
hostsToResolve []string resolveSettings resolver.ParallelSettings
hostToIPs map[string][]net.IP hostToIPs map[string][]net.IP
resolveWarnings []string resolveWarnings []string
resolveErr error resolveErr error
@@ -56,9 +58,19 @@ func Test_Updater_GetServers(t *testing.T) {
responseBody: `{"servers":[ responseBody: `{"servers":[
{"hostnames":{"openvpn":"hosta"}} {"hostnames":{"openvpn":"hosta"}}
]}`, ]}`,
responseStatus: http.StatusOK, responseStatus: http.StatusOK,
expectResolve: true, expectResolve: true,
hostsToResolve: []string{"hosta"}, resolveSettings: resolver.ParallelSettings{
Hosts: []string{"hosta"},
MaxFailRatio: 0.1,
Repeat: resolver.RepeatSettings{
MaxDuration: 20 * time.Second,
BetweenDuration: time.Second,
MaxNoNew: 2,
MaxFails: 2,
SortIPs: true,
},
},
resolveWarnings: []string{"resolve warning"}, resolveWarnings: []string{"resolve warning"},
resolveErr: errors.New("dummy"), resolveErr: errors.New("dummy"),
err: errors.New("dummy"), err: errors.New("dummy"),
@@ -86,7 +98,17 @@ func Test_Updater_GetServers(t *testing.T) {
]}`, ]}`,
responseStatus: http.StatusOK, responseStatus: http.StatusOK,
expectResolve: true, expectResolve: true,
hostsToResolve: []string{"hosta", "hostb", "hostc"}, resolveSettings: resolver.ParallelSettings{
Hosts: []string{"hosta", "hostb", "hostc"},
MaxFailRatio: 0.1,
Repeat: resolver.RepeatSettings{
MaxDuration: 20 * time.Second,
BetweenDuration: time.Second,
MaxNoNew: 2,
MaxFails: 2,
SortIPs: true,
},
},
hostToIPs: map[string][]net.IP{ hostToIPs: map[string][]net.IP{
"hosta": {{1, 1, 1, 1}, {2, 2, 2, 2}}, "hosta": {{1, 1, 1, 1}, {2, 2, 2, 2}},
"hostb": {{3, 3, 3, 3}, {4, 4, 4, 4}}, "hostb": {{3, 3, 3, 3}, {4, 4, 4, 4}},
@@ -130,7 +152,7 @@ func Test_Updater_GetServers(t *testing.T) {
presolver := common.NewMockParallelResolver(ctrl) presolver := common.NewMockParallelResolver(ctrl)
if testCase.expectResolve { if testCase.expectResolve {
presolver.EXPECT().Resolve(ctx, testCase.hostsToResolve, testCase.minServers). presolver.EXPECT().Resolve(ctx, testCase.resolveSettings).
Return(testCase.hostToIPs, testCase.resolveWarnings, testCase.resolveErr) Return(testCase.hostToIPs, testCase.resolveWarnings, testCase.resolveErr)
} }

View File

@@ -12,10 +12,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(client *http.Client, warner common.Warner) *Updater { func New(client *http.Client, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -19,12 +19,13 @@ type Provider struct {
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, unzipper common.Unzipper, client *http.Client, unzipper common.Unzipper,
updaterWarner common.Warner) *Provider { updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Privado), NoPortForwarder: utils.NewNoPortForwarding(providers.Privado),
Fetcher: updater.New(client, unzipper, updaterWarner), Fetcher: updater.New(client, unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,14 +6,15 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 30 * time.Second maxDuration = 30 * time.Second
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -22,5 +23,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -50,7 +50,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -58,6 +59,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
if len(hostToIPs) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
hts.adaptWithIPs(hostToIPs) hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice() servers = hts.toServersSlice()

View File

@@ -14,11 +14,11 @@ type Updater struct {
} }
func New(client *http.Client, unzipper common.Unzipper, func New(client *http.Client, unzipper common.Unzipper,
warner common.Warner) *Updater { warner common.Warner, parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Privatevpn),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 6 * time.Second maxDuration = 6 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -82,7 +82,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -90,6 +91,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
if len(noHostnameServers)+len(hostToIPs) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
hts.adaptWithIPs(hostToIPs) hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice() servers = hts.toServersSlice()

View File

@@ -10,10 +10,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -18,7 +18,8 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, updaterWarner common.Warner) *Provider { client *http.Client, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,

View File

@@ -44,33 +44,36 @@ type Storage interface {
} }
func NewProviders(storage Storage, timeNow func() time.Time, func NewProviders(storage Storage, timeNow func() time.Time,
updaterWarner common.Warner, client *http.Client, unzipper common.Unzipper) *Providers { updaterWarner common.Warner, client *http.Client, unzipper common.Unzipper,
parallelResolver common.ParallelResolver) *Providers {
randSource := rand.NewSource(timeNow().UnixNano()) randSource := rand.NewSource(timeNow().UnixNano())
targetLength := len(providers.AllWithCustom()) //nolint:lll
providerNameToProvider := make(map[string]Provider, targetLength) providerNameToProvider := map[string]Provider{
providerNameToProvider[providers.Custom] = custom.New() providers.Custom: custom.New(),
providerNameToProvider[providers.Cyberghost] = cyberghost.New(storage, randSource) providers.Cyberghost: cyberghost.New(storage, randSource, parallelResolver),
providerNameToProvider[providers.Expressvpn] = expressvpn.New(storage, randSource, unzipper, updaterWarner) providers.Expressvpn: expressvpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Fastestvpn] = fastestvpn.New(storage, randSource, unzipper, updaterWarner) providers.Fastestvpn: fastestvpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.HideMyAss] = hidemyass.New(storage, randSource, client, updaterWarner) providers.HideMyAss: hidemyass.New(storage, randSource, client, updaterWarner, parallelResolver),
providerNameToProvider[providers.Ipvanish] = ipvanish.New(storage, randSource, unzipper, updaterWarner) providers.Ipvanish: ipvanish.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Ivpn] = ivpn.New(storage, randSource, client, updaterWarner) providers.Ivpn: ivpn.New(storage, randSource, client, updaterWarner, parallelResolver),
providerNameToProvider[providers.Mullvad] = mullvad.New(storage, randSource, client) providers.Mullvad: mullvad.New(storage, randSource, client),
providerNameToProvider[providers.Nordvpn] = nordvpn.New(storage, randSource, client, updaterWarner) providers.Nordvpn: nordvpn.New(storage, randSource, client, updaterWarner),
providerNameToProvider[providers.Perfectprivacy] = perfectprivacy.New(storage, randSource, unzipper, updaterWarner) providers.Perfectprivacy: perfectprivacy.New(storage, randSource, unzipper, updaterWarner),
providerNameToProvider[providers.Privado] = privado.New(storage, randSource, client, unzipper, updaterWarner) providers.Privado: privado.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.PrivateInternetAccess] = privateinternetaccess.New(storage, randSource, timeNow, client) //nolint:lll providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client),
providerNameToProvider[providers.Privatevpn] = privatevpn.New(storage, randSource, unzipper, updaterWarner) providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Protonvpn] = protonvpn.New(storage, randSource, client, updaterWarner) providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, parallelResolver),
providerNameToProvider[providers.Purevpn] = purevpn.New(storage, randSource, client, unzipper, updaterWarner) providers.Purevpn: purevpn.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Surfshark] = surfshark.New(storage, randSource, client, unzipper, updaterWarner) providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Torguard] = torguard.New(storage, randSource, unzipper, updaterWarner) providers.Torguard: torguard.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.VPNUnlimited] = vpnunlimited.New(storage, randSource, unzipper, updaterWarner) providers.VPNUnlimited: vpnunlimited.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Vyprvpn] = vyprvpn.New(storage, randSource, unzipper, updaterWarner) providers.Vyprvpn: vyprvpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providerNameToProvider[providers.Wevpn] = wevpn.New(storage, randSource, updaterWarner) providers.Wevpn: wevpn.New(storage, randSource, updaterWarner, parallelResolver),
providerNameToProvider[providers.Windscribe] = windscribe.New(storage, randSource, client, updaterWarner) providers.Windscribe: windscribe.New(storage, randSource, client, updaterWarner),
}
targetLength := len(providers.AllWithCustom())
if len(providerNameToProvider) != targetLength { if len(providerNameToProvider) != targetLength {
// Programming sanity check // Programming sanity check
panic(fmt.Sprintf("invalid number of providers, expected %d but got %d", panic(fmt.Sprintf("invalid number of providers, expected %d but got %d",

View File

@@ -18,12 +18,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner) *Provider { client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Purevpn),
Fetcher: updater.New(client, unzipper, updaterWarner), Fetcher: updater.New(client, unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -60,7 +60,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -68,15 +69,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
hts.adaptWithIPs(hostToIPs) if len(hostToIPs) < minServers {
servers = hts.toServersSlice()
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d", return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers) common.ErrNotEnoughServers, len(servers), minServers)
} }
hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice()
// Get public IP address information // Get public IP address information
ipsToGetInfo := make([]net.IP, len(servers)) ipsToGetInfo := make([]net.IP, len(servers))
for i := range servers { for i := range servers {

View File

@@ -14,11 +14,11 @@ type Updater struct {
} }
func New(client *http.Client, unzipper common.Unzipper, func New(client *http.Client, unzipper common.Unzipper,
warner common.Warner) *Updater { warner common.Warner, parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -18,12 +18,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner) *Provider { client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark), NoPortForwarder: utils.NewNoPortForwarding(providers.Surfshark),
Fetcher: updater.New(client, unzipper, updaterWarner), Fetcher: updater.New(client, unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -31,7 +31,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
getRemainingServers(hts) getRemainingServers(hts)
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -39,15 +40,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
hts.adaptWithIPs(hostToIPs) if len(hostToIPs) < minServers {
servers = hts.toServersSlice()
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d", return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers) common.ErrNotEnoughServers, len(servers), minServers)
} }
hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice()
sort.Sort(models.SortableServers(servers)) sort.Sort(models.SortableServers(servers))
return servers, nil return servers, nil

View File

@@ -14,11 +14,11 @@ type Updater struct {
} }
func New(client *http.Client, unzipper common.Unzipper, func New(client *http.Client, unzipper common.Unzipper,
warner common.Warner) *Updater { warner common.Warner, parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard), NoPortForwarder: utils.NewNoPortForwarding(providers.Torguard),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -50,12 +50,18 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
u.warnWarnings(warnings) u.warnWarnings(warnings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(hostToIPs) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers)
}
hts.adaptWithIPs(hostToIPs) hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice() servers = hts.toServersSlice()

View File

@@ -10,10 +10,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited), NoPortForwarder: utils.NewNoPortForwarding(providers.VPNUnlimited),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -20,7 +20,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -28,15 +29,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
hts.adaptWithIPs(hostToIPs) if len(hostToIPs) < minServers {
servers = hts.toServersSlice()
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d", return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers) common.ErrNotEnoughServers, len(servers), minServers)
} }
hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice()
sort.Sort(models.SortableServers(servers)) sort.Sort(models.SortableServers(servers))
return servers, nil return servers, nil

View File

@@ -10,10 +10,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
unzipper common.Unzipper, updaterWarner common.Warner) *Provider { unzipper common.Unzipper, updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Vyprvpn),
Fetcher: updater.New(unzipper, updaterWarner), Fetcher: updater.New(unzipper, updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 5 * time.Second maxDuration = 5 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -64,7 +64,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
} }
hosts := hts.toHostsSlice() hosts := hts.toHostsSlice()
hostToIPs, warnings, err := u.presolver.Resolve(ctx, hosts, minServers) resolveSettings := parallelResolverSettings(hosts)
hostToIPs, warnings, err := u.presolver.Resolve(ctx, resolveSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }
@@ -72,15 +73,15 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return nil, err return nil, err
} }
hts.adaptWithIPs(hostToIPs) if len(hostToIPs) < minServers {
servers = hts.toServersSlice()
if len(servers) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d", return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(servers), minServers) common.ErrNotEnoughServers, len(servers), minServers)
} }
hts.adaptWithIPs(hostToIPs)
servers = hts.toServersSlice()
sort.Sort(models.SortableServers(servers)) sort.Sort(models.SortableServers(servers))
return servers, nil return servers, nil

View File

@@ -10,10 +10,11 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(unzipper common.Unzipper, warner common.Warner) *Updater { func New(unzipper common.Unzipper, warner common.Warner,
parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
unzipper: unzipper, unzipper: unzipper,
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -93,7 +93,8 @@ func Test_Provider_GetConnection(t *testing.T) {
randSource := rand.NewSource(0) randSource := rand.NewSource(0)
warner := (common.Warner)(nil) warner := (common.Warner)(nil)
provider := New(storage, randSource, warner) parallelResolver := (common.ParallelResolver)(nil)
provider := New(storage, randSource, warner, parallelResolver)
if testCase.panicMessage != "" { if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() { assert.PanicsWithValue(t, testCase.panicMessage, func() {

View File

@@ -17,12 +17,13 @@ type Provider struct {
} }
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
updaterWarner common.Warner) *Provider { updaterWarner common.Warner,
parallelResolver common.ParallelResolver) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn), NoPortForwarder: utils.NewNoPortForwarding(providers.Wevpn),
Fetcher: updater.New(updaterWarner), Fetcher: updater.New(updaterWarner, parallelResolver),
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/gluetun/internal/updater/resolver"
) )
func newParallelResolver() (parallelResolver *resolver.Parallel) { func parallelResolverSettings(hosts []string) (settings resolver.ParallelSettings) {
const ( const (
maxFailRatio = 0.1 maxFailRatio = 0.1
maxDuration = 20 * time.Second maxDuration = 20 * time.Second
@@ -14,7 +14,8 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
maxNoNew = 2 maxNoNew = 2
maxFails = 2 maxFails = 2
) )
settings := resolver.ParallelSettings{ return resolver.ParallelSettings{
Hosts: hosts,
MaxFailRatio: maxFailRatio, MaxFailRatio: maxFailRatio,
Repeat: resolver.RepeatSettings{ Repeat: resolver.RepeatSettings{
MaxDuration: maxDuration, MaxDuration: maxDuration,
@@ -24,5 +25,4 @@ func newParallelResolver() (parallelResolver *resolver.Parallel) {
SortIPs: true, SortIPs: true,
}, },
} }
return resolver.NewParallelResolver(settings)
} }

View File

@@ -31,7 +31,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
hostnameToCity[hostname] = city hostnameToCity[hostname] = city
} }
hostnameToIPs, warnings, err := u.presolver.Resolve(ctx, hostnames, minServers) resolverSettings := parallelResolverSettings(hostnames)
hostnameToIPs, warnings, err := u.presolver.Resolve(ctx, resolverSettings)
for _, warning := range warnings { for _, warning := range warnings {
u.warner.Warn(warning) u.warner.Warn(warning)
} }

View File

@@ -7,9 +7,9 @@ type Updater struct {
warner common.Warner warner common.Warner
} }
func New(warner common.Warner) *Updater { func New(warner common.Warner, parallelResolver common.ParallelResolver) *Updater {
return &Updater{ return &Updater{
presolver: newParallelResolver(), presolver: parallelResolver,
warner: warner, warner: warner,
} }
} }

View File

@@ -9,17 +9,17 @@ import (
type Parallel struct { type Parallel struct {
repeatResolver *Repeat repeatResolver *Repeat
settings ParallelSettings
} }
func NewParallelResolver(settings ParallelSettings) *Parallel { func NewParallelResolver(resolverAddress string) *Parallel {
return &Parallel{ return &Parallel{
repeatResolver: NewRepeat(settings.Repeat), repeatResolver: NewRepeat(resolverAddress),
settings: settings,
} }
} }
type ParallelSettings struct { type ParallelSettings struct {
// Hosts to resolve in parallel.
Hosts []string
Repeat RepeatSettings Repeat RepeatSettings
FailEarly bool FailEarly bool
// Maximum ratio of the hosts failing DNS resolution // Maximum ratio of the hosts failing DNS resolution
@@ -39,7 +39,7 @@ var (
ErrMaxFailRatio = errors.New("maximum failure ratio reached") ErrMaxFailRatio = errors.New("maximum failure ratio reached")
) )
func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int) ( func (pr *Parallel) Resolve(ctx context.Context, settings ParallelSettings) (
hostToIPs map[string][]net.IP, warnings []string, err error) { hostToIPs map[string][]net.IP, warnings []string, err error) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@@ -49,17 +49,17 @@ func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int)
errors := make(chan error) errors := make(chan error)
defer close(errors) defer close(errors)
for _, host := range hosts { for _, host := range settings.Hosts {
go pr.resolveAsync(ctx, host, results, errors) go pr.resolveAsync(ctx, host, settings.Repeat, results, errors)
} }
hostToIPs = make(map[string][]net.IP, len(hosts)) hostToIPs = make(map[string][]net.IP, len(settings.Hosts))
maxFails := int(pr.settings.MaxFailRatio * float64(len(hosts))) maxFails := int(settings.MaxFailRatio * float64(len(settings.Hosts)))
for range hosts { for range settings.Hosts {
select { select {
case newErr := <-errors: case newErr := <-errors:
if pr.settings.FailEarly { if settings.FailEarly {
if err == nil { if err == nil {
// only set the error to the first error encountered // only set the error to the first error encountered
// and not the context canceled errors coming after. // and not the context canceled errors coming after.
@@ -86,14 +86,8 @@ func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int)
return nil, warnings, err return nil, warnings, err
} }
if len(hostToIPs) < minToFind { failureRatio := float64(len(warnings)) / float64(len(settings.Hosts))
return nil, warnings, if failureRatio > settings.MaxFailRatio {
fmt.Errorf("%w: found %d hosts but expected at least %d",
ErrMinFound, len(hostToIPs), minToFind)
}
failureRatio := float64(len(warnings)) / float64(len(hosts))
if failureRatio > pr.settings.MaxFailRatio {
return hostToIPs, warnings, return hostToIPs, warnings,
fmt.Errorf("%w: %.2f failure ratio reached", ErrMaxFailRatio, failureRatio) fmt.Errorf("%w: %.2f failure ratio reached", ErrMaxFailRatio, failureRatio)
} }
@@ -102,8 +96,8 @@ func (pr *Parallel) Resolve(ctx context.Context, hosts []string, minToFind int)
} }
func (pr *Parallel) resolveAsync(ctx context.Context, host string, func (pr *Parallel) resolveAsync(ctx context.Context, host string,
results chan<- parallelResult, errors chan<- error) { settings RepeatSettings, results chan<- parallelResult, errors chan<- error) {
IPs, err := pr.repeatResolver.Resolve(ctx, host) IPs, err := pr.repeatResolver.Resolve(ctx, host, settings)
if err != nil { if err != nil {
errors <- err errors <- err
return return

View File

@@ -12,13 +12,11 @@ import (
type Repeat struct { type Repeat struct {
resolver *net.Resolver resolver *net.Resolver
settings RepeatSettings
} }
func NewRepeat(settings RepeatSettings) *Repeat { func NewRepeat(resolverAddress string) *Repeat {
return &Repeat{ return &Repeat{
resolver: newResolver(settings.Address), resolver: newResolver(resolverAddress),
settings: settings,
} }
} }
@@ -32,8 +30,9 @@ type RepeatSettings struct {
SortIPs bool SortIPs bool
} }
func (r *Repeat) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { func (r *Repeat) Resolve(ctx context.Context, host string, settings RepeatSettings) (
timedCtx, cancel := context.WithTimeout(ctx, r.settings.MaxDuration) ips []net.IP, err error) {
timedCtx, cancel := context.WithTimeout(ctx, settings.MaxDuration)
defer cancel() defer cancel()
noNewCounter := 0 noNewCounter := 0
@@ -44,7 +43,7 @@ func (r *Repeat) Resolve(ctx context.Context, host string) (ips []net.IP, err er
// TODO // TODO
// - one resolving every 100ms for round robin DNS responses // - one resolving every 100ms for round robin DNS responses
// - one every second for time based DNS cycling responses // - one every second for time based DNS cycling responses
noNewCounter, failCounter, err = r.resolveOnce(ctx, timedCtx, host, r.settings, uniqueIPs, noNewCounter, failCounter) noNewCounter, failCounter, err = r.resolveOnce(ctx, timedCtx, host, settings, uniqueIPs, noNewCounter, failCounter)
} }
if len(uniqueIPs) == 0 { if len(uniqueIPs) == 0 {
@@ -53,7 +52,7 @@ func (r *Repeat) Resolve(ctx context.Context, host string) (ips []net.IP, err er
ips = uniqueIPsToSlice(uniqueIPs) ips = uniqueIPsToSlice(uniqueIPs)
if r.settings.SortIPs { if settings.SortIPs {
sort.Slice(ips, func(i, j int) bool { sort.Slice(ips, func(i, j int) bool {
return bytes.Compare(ips[i], ips[j]) < 1 return bytes.Compare(ips[i], ips[j]) < 1
}) })