From a19efbd923b1c3c65746ced15cbb7487cb6da517 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 12 Sep 2020 14:04:54 -0400 Subject: [PATCH] Updater loop with period and http route (#240) * Updater loop with period and http route * Using DNS over TLS to update servers * Better logging * Remove goroutines for cyberghost updater * Respects context for servers update (quite slow overall) * Increase shutdown grace period to 5 seconds * Update announcement * Add log lines for each provider update start --- README.md | 1 + cmd/gluetun/main.go | 18 +++- internal/cli/cli.go | 23 ++++-- internal/constants/splash.go | 4 +- internal/openvpn/loop.go | 16 +++- internal/params/params.go | 2 + internal/params/updater.go | 17 ++++ internal/server/server.go | 6 +- internal/settings/settings.go | 13 ++- internal/updater/cyberghost.go | 49 +++++------ internal/updater/loop.go | 146 +++++++++++++++++++++++++++++++++ internal/updater/nordvpn.go | 6 +- internal/updater/options.go | 21 ++++- internal/updater/pia.go | 3 + internal/updater/purevpn.go | 9 +- internal/updater/surfshark.go | 3 + internal/updater/updater.go | 84 ++++++++++++------- internal/updater/vyprvpn.go | 3 + internal/updater/windscribe.go | 16 +++- 19 files changed, 358 insertions(+), 82 deletions(-) create mode 100644 internal/params/updater.go create mode 100644 internal/updater/loop.go diff --git a/README.md b/README.md index 22dcbfdc..0a9e6fc4 100644 --- a/README.md +++ b/README.md @@ -260,6 +260,7 @@ That one is important if you want to connect to the container from your LAN for | --- | --- | --- | --- | | `PUBLICIP_PERIOD` | `12h` | Valid duration | Period to check for public IP address. Set to `0` to disable. | | `VERSION_INFORMATION` | `on` | `on`, `off` | Logs a message indicating if a newer version is available once the VPN is connected | +| `UPDATER_PERIOD` | `0` | Valid duration string such as `24h` | Period to update all VPN servers information in memory and to /gluetun/servers.json. Set to `0` to disable. This does a burst of DNS over TLS requests, which may be blocked if you set `BLOCK_MALICIOUS=on` for example. | ## Connect to it diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 9bd60835..be480865 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -26,6 +26,7 @@ import ( "github.com/qdm12/gluetun/internal/shadowsocks" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/tinyproxy" + "github.com/qdm12/gluetun/internal/updater" versionpkg "github.com/qdm12/gluetun/internal/version" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/files" @@ -70,6 +71,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go defer cancel() logger := createLogger() + httpClient := &http.Client{Timeout: 15 * time.Second} client := network.NewClient(15 * time.Second) // Create configurators fileManager := files.NewFileManager() @@ -195,6 +197,12 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go // wait for restartOpenvpn go openvpnLooper.Run(ctx, wg) + updaterOptions := updater.NewOptions("127.0.0.1") + updaterLooper := updater.NewLooper(updaterOptions, allSettings.UpdaterPeriod, allServers, storage, openvpnLooper.SetAllServers, httpClient, logger) + wg.Add(1) + // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker + go updaterLooper.Run(ctx, wg) + unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) restartUnbound := unboundLooper.Restart // wait for restartUnbound @@ -226,8 +234,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go if !allSettings.VersionInformation { return } - client := &http.Client{Timeout: 5 * time.Second} - message, err := versionpkg.GetMessage(version, commit, client) + message, err := versionpkg.GetMessage(version, commit, httpClient) if err != nil { logger.Error(err) return @@ -246,12 +253,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go restartTickerCancel() restartTickerContext, restartTickerCancel = context.WithCancel(ctx) go unboundLooper.RunRestartTicker(restartTickerContext) + go updaterLooper.RunRestartTicker(ctx) onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation) } } }() - httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, getOpenvpnSettings, getPortForwarded) + httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart, + getOpenvpnSettings, getPortForwarded) go httpServer.Run(ctx, wg) // Start openvpn for the first time @@ -283,7 +292,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go shutdownErrorsCount++ } } - waiting, waited := context.WithTimeout(context.Background(), time.Second) + const shutdownGracePeriod = 5 * time.Second + waiting, waited := context.WithTimeout(context.Background(), shutdownGracePeriod) go func() { defer waited() wg.Wait() diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 11ec3bfa..0f6197e8 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -90,9 +90,10 @@ func OpenvpnConfig() error { } func Update(args []string) error { - var options updater.Options + options := updater.Options{CLI: true} + var flushToFile bool flagSet := flag.NewFlagSet("update", flag.ExitOnError) - flagSet.BoolVar(&options.File, "file", false, "Write results to /gluetun/servers.json (for end users)") + flagSet.BoolVar(&flushToFile, "file", true, "Write results to /gluetun/servers.json (for end users)") flagSet.BoolVar(&options.Stdout, "stdout", false, "Write results to console to modify the program (for maintainers)") flagSet.StringVar(&options.DNSAddress, "dns", "1.1.1.1", "DNS resolver address to use") flagSet.BoolVar(&options.Cyberghost, "cyberghost", false, "Update Cyberghost servers") @@ -110,15 +111,27 @@ func Update(args []string) error { if err != nil { return err } - if !options.File && !options.Stdout { + if !flushToFile && !options.Stdout { return fmt.Errorf("at least one of -file or -stdout must be specified") } ctx := context.Background() httpClient := &http.Client{Timeout: 10 * time.Second} storage := storage.New(logger) - updater := updater.New(options, storage, httpClient) - if err := updater.UpdateServers(ctx); err != nil { + const writeSync = false + currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync) + if err != nil { + return fmt.Errorf("cannot update servers: %w", err) + } + updater := updater.New(options, httpClient, currentServers, logger) + allServers, err := updater.UpdateServers(ctx) + if err != nil { return err } + if flushToFile { + if err := storage.FlushToFile(allServers); err != nil { + return fmt.Errorf("cannot update servers: %w", err) + } + } + return nil } diff --git a/internal/constants/splash.go b/internal/constants/splash.go index 866cbbf7..3172803d 100644 --- a/internal/constants/splash.go +++ b/internal/constants/splash.go @@ -2,9 +2,9 @@ package constants const ( // Announcement is a message announcement - Announcement = "Persistent server IP addresses at /gluetun/servers.json, please BIND MOUNT" + Announcement = "Update servers information see https://github.com/qdm12/gluetun/wiki/Update-servers-information" // AnnouncementExpiration is the expiration date of the announcement in format yyyy-mm-dd - AnnouncementExpiration = "2020-09-30" + AnnouncementExpiration = "2020-10-10" ) const ( diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 3a7e7675..e6aa46e0 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -24,6 +24,7 @@ type Looper interface { GetSettings() (settings settings.OpenVPN) SetSettings(settings settings.OpenVPN) GetPortForwarded() (portForwarded uint16) + SetAllServers(allServers models.AllServers) } type looper struct { @@ -33,10 +34,11 @@ type looper struct { settingsMutex sync.RWMutex portForwarded uint16 portForwardedMutex sync.RWMutex + allServers models.AllServers + allServersMutex sync.RWMutex // Fixed parameters - uid int - gid int - allServers models.AllServers + uid int + gid int // Configurators conf Configurator fw firewall.Configurator @@ -89,6 +91,12 @@ func (l *looper) SetSettings(settings settings.OpenVPN) { l.settings = settings } +func (l *looper) SetAllServers(allServers models.AllServers) { + l.allServersMutex.Lock() + defer l.allServersMutex.Unlock() + l.allServers = allServers +} + func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() @@ -101,7 +109,9 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for ctx.Err() == nil { settings := l.GetSettings() + l.allServersMutex.RLock() providerConf := provider.New(l.provider, l.allServers) + l.allServersMutex.RUnlock() connections, err := providerConf.GetOpenVPNConnections(settings.Provider.ServerSelection) if err != nil { l.logger.Error(err) diff --git a/internal/params/params.go b/internal/params/params.go index 1c765dd8..959cf2eb 100644 --- a/internal/params/params.go +++ b/internal/params/params.go @@ -110,6 +110,8 @@ type Reader interface { GetPublicIPPeriod() (period time.Duration, err error) GetVersionInformation() (enabled bool, err error) + + GetUpdaterPeriod() (period time.Duration, err error) } type reader struct { diff --git a/internal/params/updater.go b/internal/params/updater.go new file mode 100644 index 00000000..96a7bada --- /dev/null +++ b/internal/params/updater.go @@ -0,0 +1,17 @@ +package params + +import ( + "time" + + libparams "github.com/qdm12/golibs/params" +) + +// GetUpdaterPeriod obtains the period to fetch the servers information when the tunnel is up. +// Set to 0 to disable +func (r *reader) GetUpdaterPeriod() (period time.Duration, err error) { + s, err := r.envParams.GetEnv("UPDATER_PERIOD", libparams.Default("0")) + if err != nil { + return 0, err + } + return time.ParseDuration(s) +} diff --git a/internal/server/server.go b/internal/server/server.go index 10ef8c4f..3819362b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,18 +21,20 @@ type server struct { logger logging.Logger restartOpenvpn func() restartUnbound func() + restartUpdater func() getOpenvpnSettings func() settings.OpenVPN getPortForwarded func() uint16 lookupIP func(host string) ([]net.IP, error) } -func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound func(), +func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound, restartUpdater func(), getOpenvpnSettings func() settings.OpenVPN, getPortForwarded func() uint16) Server { return &server{ address: address, logger: logger.WithPrefix("http server: "), restartOpenvpn: restartOpenvpn, restartUnbound: restartUnbound, + restartUpdater: restartUpdater, getOpenvpnSettings: getOpenvpnSettings, getPortForwarded: getPortForwarded, lookupIP: net.LookupIP, @@ -76,6 +78,8 @@ func (s *server) makeHandler() http.HandlerFunc { s.handleGetOpenvpnSettings(w) case "/health": s.handleHealth(w) + case "/updater/restart": + s.restartUpdater() default: routeDoesNotExist(s.logger, w, r) } diff --git a/internal/settings/settings.go b/internal/settings/settings.go index ab8131ca..823cc5be 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -1,6 +1,7 @@ package settings import ( + "fmt" "strings" "time" @@ -23,6 +24,7 @@ type Settings struct { TinyProxy TinyProxy ShadowSocks ShadowSocks PublicIPPeriod time.Duration + UpdaterPeriod time.Duration VersionInformation bool } @@ -31,6 +33,10 @@ func (s *Settings) String() string { if s.VersionInformation { versionInformation = enabled } + updaterLine := "Updater: disabled" + if s.UpdaterPeriod > 0 { + updaterLine = fmt.Sprintf("Updater period: %s", s.UpdaterPeriod) + } return strings.Join([]string{ "Settings summary below:", s.OpenVPN.String(), @@ -39,8 +45,9 @@ func (s *Settings) String() string { s.Firewall.String(), s.TinyProxy.String(), s.ShadowSocks.String(), - "Public IP check period: " + s.PublicIPPeriod.String(), + "Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0 "Version information: " + versionInformation, + updaterLine, "", // new line at the end }, "\n") } @@ -84,5 +91,9 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) { if err != nil { return settings, err } + settings.UpdaterPeriod, err = paramsReader.GetUpdaterPeriod() + if err != nil { + return settings, err + } return settings, nil } diff --git a/internal/updater/cyberghost.go b/internal/updater/cyberghost.go index c076a174..c458ecdd 100644 --- a/internal/updater/cyberghost.go +++ b/internal/updater/cyberghost.go @@ -8,53 +8,46 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -func (u *updater) updateCyberghost(ctx context.Context) { - servers := findCyberghostServers(ctx, u.lookupIP) +func (u *updater) updateCyberghost(ctx context.Context) (err error) { + servers, err := findCyberghostServers(ctx, u.lookupIP) + if err != nil { + return err + } if u.options.Stdout { u.println(stringifyCyberghostServers(servers)) } u.servers.Cyberghost.Timestamp = u.timeNow().Unix() u.servers.Cyberghost.Servers = servers + return nil } -func findCyberghostServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.CyberghostServer) { +func findCyberghostServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.CyberghostServer, err error) { groups := getCyberghostGroups() allCountryCodes := getCountryCodes() cyberghostCountryCodes := getCyberghostSubdomainToRegion() possibleCountryCodes := mergeCountryCodes(cyberghostCountryCodes, allCountryCodes) - resultsChannel := make(chan models.CyberghostServer) - const maxGoroutines = 10 - guard := make(chan struct{}, maxGoroutines) for groupID, groupName := range groups { for countryCode, region := range possibleCountryCodes { - go func(groupName, groupID, region, countryCode string) { - host := fmt.Sprintf("%s-%s.cg-dialup.net", groupID, countryCode) - guard <- struct{}{} - IPs, err := resolveRepeat(ctx, lookupIP, host, 2) - if err != nil { - IPs = nil - } - <-guard - resultsChannel <- models.CyberghostServer{ - Region: region, - Group: groupName, - IPs: IPs, - } - }(groupName, groupID, region, countryCode) + if err := ctx.Err(); err != nil { + return nil, err + } + host := fmt.Sprintf("%s-%s.cg-dialup.net", groupID, countryCode) + IPs, err := resolveRepeat(ctx, lookupIP, host, 2) + if err != nil || len(IPs) == 0 { + continue + } + servers = append(servers, models.CyberghostServer{ + Region: region, + Group: groupName, + IPs: IPs, + }) } } - for i := 0; i < len(groups)*len(possibleCountryCodes); i++ { - server := <-resultsChannel - if server.IPs == nil { - continue - } - servers = append(servers, server) - } sort.Slice(servers, func(i, j int) bool { return servers[i].Region < servers[j].Region }) - return servers + return servers, nil } //nolint:goconst diff --git a/internal/updater/loop.go b/internal/updater/loop.go new file mode 100644 index 00000000..12557afb --- /dev/null +++ b/internal/updater/loop.go @@ -0,0 +1,146 @@ +package updater + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/storage" + "github.com/qdm12/golibs/logging" +) + +type Looper interface { + Run(ctx context.Context, wg *sync.WaitGroup) + RunRestartTicker(ctx context.Context) + Restart() + Stop() + GetPeriod() (period time.Duration) + SetPeriod(period time.Duration) +} + +type looper struct { + period time.Duration + periodMutex sync.RWMutex + updater Updater + storage storage.Storage + setAllServers func(allServers models.AllServers) + logger logging.Logger + restart chan struct{} + stop chan struct{} + updateTicker chan struct{} +} + +func NewLooper(options Options, period time.Duration, currentServers models.AllServers, + storage storage.Storage, setAllServers func(allServers models.AllServers), + client *http.Client, logger logging.Logger) Looper { + loggerWithPrefix := logger.WithPrefix("updater: ") + return &looper{ + period: period, + updater: New(options, client, currentServers, loggerWithPrefix), + storage: storage, + setAllServers: setAllServers, + logger: loggerWithPrefix, + restart: make(chan struct{}), + stop: make(chan struct{}), + updateTicker: make(chan struct{}), + } +} + +func (l *looper) Restart() { l.restart <- struct{}{} } +func (l *looper) Stop() { l.stop <- struct{}{} } + +func (l *looper) GetPeriod() (period time.Duration) { + l.periodMutex.RLock() + defer l.periodMutex.RUnlock() + return l.period +} + +func (l *looper) SetPeriod(period time.Duration) { + l.periodMutex.Lock() + l.period = period + l.periodMutex.Unlock() + l.updateTicker <- struct{}{} +} + +func (l *looper) logAndWait(ctx context.Context, err error) { + l.logger.Error(err) + l.logger.Info("retrying in 5 minutes") + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() // just for the linter + <-ctx.Done() +} + +func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + select { + case <-l.restart: + l.logger.Info("starting...") + case <-ctx.Done(): + return + } + defer l.logger.Warn("loop exited") + + enabled := true + + for ctx.Err() == nil { + for !enabled { + // wait for a signal to re-enable + select { + case <-l.stop: + l.logger.Info("already disabled") + case <-l.restart: + enabled = true + case <-ctx.Done(): + return + } + } + + // Enabled and has a period set + + servers, err := l.updater.UpdateServers(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + l.logAndWait(ctx, err) + continue + } + l.setAllServers(servers) + if err := l.storage.FlushToFile(servers); err != nil { + l.logger.Error(err) + } + l.logger.Info("Updated servers information") + + select { + case <-l.restart: // triggered restart + case <-l.stop: + enabled = false + case <-ctx.Done(): + return + } + } +} + +func (l *looper) RunRestartTicker(ctx context.Context) { + ticker := time.NewTicker(time.Hour) + period := l.GetPeriod() + if period > 0 { + ticker = time.NewTicker(period) + } else { + ticker.Stop() + } + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + l.restart <- struct{}{} + case <-l.updateTicker: + ticker.Stop() + ticker = time.NewTicker(l.GetPeriod()) + } + } +} diff --git a/internal/updater/nordvpn.go b/internal/updater/nordvpn.go index 6115a3df..cc6630c0 100644 --- a/internal/updater/nordvpn.go +++ b/internal/updater/nordvpn.go @@ -15,8 +15,10 @@ import ( func (u *updater) updateNordvpn() (err error) { servers, warnings, err := findNordvpnServers(u.httpGet) - for _, warning := range warnings { - u.println(warning) + if u.options.CLI { + for _, warning := range warnings { + u.logger.Warn("Nordvpn: %s", warning) + } } if err != nil { return fmt.Errorf("cannot update Nordvpn servers: %w", err) diff --git a/internal/updater/options.go b/internal/updater/options.go index b9c872a4..66badcb7 100644 --- a/internal/updater/options.go +++ b/internal/updater/options.go @@ -10,7 +10,24 @@ type Options struct { Surfshark bool Vyprvpn bool Windscribe bool - File bool // update JSON file (user side) - Stdout bool // update constants file (maintainer side) + Stdout bool // in order to update constants file (maintainer side) + CLI bool DNSAddress string } + +func NewOptions(dnsAddress string) Options { + return Options{ + Cyberghost: true, + Mullvad: true, + Nordvpn: true, + PIA: true, + PIAold: true, + Purevpn: true, + Surfshark: true, + Vyprvpn: true, + Windscribe: true, + Stdout: false, + CLI: false, + DNSAddress: dnsAddress, + } +} diff --git a/internal/updater/pia.go b/internal/updater/pia.go index 0f0ade7e..b9d86421 100644 --- a/internal/updater/pia.go +++ b/internal/updater/pia.go @@ -52,6 +52,9 @@ func (u *updater) updatePIAOld(ctx context.Context) (err error) { } servers := make([]models.PIAServer, 0, len(contents)) for fileName, content := range contents { + if err := ctx.Err(); err != nil { + return err + } remoteLines := extractRemoteLinesFromOpenvpn(content) if len(remoteLines) == 0 { return fmt.Errorf("cannot find any remote lines in %s", fileName) diff --git a/internal/updater/purevpn.go b/internal/updater/purevpn.go index f021911e..611ffc2d 100644 --- a/internal/updater/purevpn.go +++ b/internal/updater/purevpn.go @@ -14,8 +14,10 @@ import ( func (u *updater) updatePurevpn(ctx context.Context) (err error) { servers, warnings, err := findPurevpnServers(ctx, u.httpGet, u.lookupIP) - for _, warning := range warnings { - u.println(warning) + if u.options.CLI { + for _, warning := range warnings { + u.logger.Warn("PureVPN: %s", warning) + } } if err != nil { return fmt.Errorf("cannot update Purevpn servers: %w", err) @@ -76,6 +78,9 @@ func findPurevpnServers(ctx context.Context, httpGet httpGetFunc, lookupIP looku return data[i].Region < data[j].Region }) for _, jsonServer := range data { + if err := ctx.Err(); err != nil { + return nil, warnings, err + } if jsonServer.UDP == "" && jsonServer.TCP == "" { warnings = append(warnings, fmt.Sprintf("server %s %s %s does not support TCP and UDP for openvpn", jsonServer.Region, jsonServer.Country, jsonServer.City)) continue diff --git a/internal/updater/surfshark.go b/internal/updater/surfshark.go index b3ebc89f..7cdfe3d4 100644 --- a/internal/updater/surfshark.go +++ b/internal/updater/surfshark.go @@ -30,6 +30,9 @@ func findSurfsharkServers(ctx context.Context, lookupIP lookupIPFunc) (servers [ return nil, err } for fileName, content := range contents { + if err := ctx.Err(); err != nil { + return nil, err + } if strings.HasSuffix(fileName, "_tcp.ovpn") { continue // only parse UDP files } diff --git a/internal/updater/updater.go b/internal/updater/updater.go index d5994d90..14be6fa7 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -6,110 +6,138 @@ import ( "net/http" "time" - "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/storage" + "github.com/qdm12/golibs/logging" ) type Updater interface { - UpdateServers(ctx context.Context) error + UpdateServers(ctx context.Context) (allServers models.AllServers, err error) } type updater struct { // configuration options Options - storage storage.Storage // state servers models.AllServers // Functions for tests + logger logging.Logger timeNow func() time.Time println func(s string) httpGet httpGetFunc lookupIP lookupIPFunc } -func New(options Options, storage storage.Storage, httpClient *http.Client) Updater { +func New(options Options, httpClient *http.Client, currentServers models.AllServers, logger logging.Logger) Updater { if len(options.DNSAddress) == 0 { options.DNSAddress = "1.1.1.1" } resolver := newResolver(options.DNSAddress) return &updater{ - storage: storage, + logger: logger, timeNow: time.Now, println: func(s string) { fmt.Println(s) }, httpGet: httpClient.Get, lookupIP: newLookupIP(resolver), options: options, + servers: currentServers, } } // TODO parallelize DNS resolution -func (u *updater) UpdateServers(ctx context.Context) (err error) { - const writeSync = false - u.servers, err = u.storage.SyncServers(constants.GetAllServers(), writeSync) - if err != nil { - return fmt.Errorf("cannot update servers: %w", err) - } - +func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) { //nolint:gocognit if u.options.Cyberghost { - u.updateCyberghost(ctx) + u.logger.Info("updating Cyberghost servers...") + if err := u.updateCyberghost(ctx); err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return allServers, ctxErr + } + u.logger.Error(err) + } } if u.options.Mullvad { + u.logger.Info("updating Mullvad servers...") if err := u.updateMullvad(); err != nil { - return err + u.logger.Error(err) + } + if err := ctx.Err(); err != nil { + return allServers, err } } if u.options.Nordvpn { // TODO support servers offering only TCP or only UDP + u.logger.Info("updating NordVPN servers...") if err := u.updateNordvpn(); err != nil { - return err + u.logger.Error(err) + } + if err := ctx.Err(); err != nil { + return allServers, err } } if u.options.PIA { + u.logger.Info("updating Private Internet Access (v4) servers...") if err := u.updatePIA(); err != nil { - return err + u.logger.Error(err) + } + if ctx.Err() != nil { + return allServers, ctx.Err() } } if u.options.PIAold { + u.logger.Info("updating Private Internet Access old (v3) servers...") if err := u.updatePIAOld(ctx); err != nil { - return err + if ctxErr := ctx.Err(); ctxErr != nil { + return allServers, ctxErr + } + u.logger.Error(err) } } if u.options.Purevpn { + u.logger.Info("updating PureVPN servers...") // TODO support servers offering only TCP or only UDP if err := u.updatePurevpn(ctx); err != nil { - return err + if ctxErr := ctx.Err(); ctxErr != nil { + return allServers, ctxErr + } + u.logger.Error(err) } } if u.options.Surfshark { + u.logger.Info("updating Surfshark servers...") if err := u.updateSurfshark(ctx); err != nil { - return err + if ctxErr := ctx.Err(); ctxErr != nil { + return allServers, ctxErr + } + u.logger.Error(err) } } if u.options.Vyprvpn { + u.logger.Info("updating Vyprvpn servers...") if err := u.updateVyprvpn(ctx); err != nil { - return err + if ctxErr := ctx.Err(); ctxErr != nil { + return allServers, ctxErr + } + u.logger.Error(err) } } if u.options.Windscribe { - u.updateWindscribe(ctx) - } - - if u.options.File { - if err := u.storage.FlushToFile(u.servers); err != nil { - return fmt.Errorf("cannot update servers: %w", err) + u.logger.Info("updating Windscribe servers...") + if err := u.updateWindscribe(ctx); err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return allServers, ctxErr + } + u.logger.Error(err) } } - return nil + return u.servers, nil } diff --git a/internal/updater/vyprvpn.go b/internal/updater/vyprvpn.go index 87a8c82a..f58a2e80 100644 --- a/internal/updater/vyprvpn.go +++ b/internal/updater/vyprvpn.go @@ -30,6 +30,9 @@ func findVyprvpnServers(ctx context.Context, lookupIP lookupIPFunc) (servers []m return nil, err } for fileName, content := range contents { + if err := ctx.Err(); err != nil { + return nil, err + } remoteLines := extractRemoteLinesFromOpenvpn(content) if len(remoteLines) == 0 { return nil, fmt.Errorf("cannot find any remote lines in %s", fileName) diff --git a/internal/updater/windscribe.go b/internal/updater/windscribe.go index fb97409c..00198645 100644 --- a/internal/updater/windscribe.go +++ b/internal/updater/windscribe.go @@ -2,26 +2,34 @@ package updater import ( "context" + "fmt" "sort" "github.com/qdm12/gluetun/internal/models" ) -func (u *updater) updateWindscribe(ctx context.Context) { - servers := findWindscribeServers(ctx, u.lookupIP) +func (u *updater) updateWindscribe(ctx context.Context) (err error) { + servers, err := findWindscribeServers(ctx, u.lookupIP) + if err != nil { + return fmt.Errorf("cannot update Windscribe servers: %w", err) + } if u.options.Stdout { u.println(stringifyWindscribeServers(servers)) } u.servers.Windscribe.Timestamp = u.timeNow().Unix() u.servers.Windscribe.Servers = servers + return nil } -func findWindscribeServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.WindscribeServer) { +func findWindscribeServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.WindscribeServer, err error) { allCountryCodes := getCountryCodes() windscribeCountryCodes := getWindscribeSubdomainToRegion() possibleCountryCodes := mergeCountryCodes(windscribeCountryCodes, allCountryCodes) const domain = "windscribe.com" for countryCode, region := range possibleCountryCodes { + if err := ctx.Err(); err != nil { + return nil, err + } host := countryCode + "." + domain ips, err := resolveRepeat(ctx, lookupIP, host, 2) if err != nil || len(ips) == 0 { @@ -35,7 +43,7 @@ func findWindscribeServers(ctx context.Context, lookupIP lookupIPFunc) (servers sort.Slice(servers, func(i, j int) bool { return servers[i].Region < servers[j].Region }) - return servers + return servers, nil } func mergeCountryCodes(base, extend map[string]string) (merged map[string]string) {