From 4257581f55a3b6b07a3401fe1d2e74683c86d807 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 19 Dec 2020 20:10:34 -0500 Subject: [PATCH] Loops and HTTP control server rework (#308) - CRUD REST HTTP server - `/v1` HTTP server prefix - Retrocompatible with older routes (redirects to v1 or handles the requests directly) - DNS, Updater and Openvpn refactored to have a REST-like state with new methods to change their states synchronously - Openvpn, Unbound and Updater status, see #287 --- cmd/gluetun/main.go | 14 +- internal/cli/cli.go | 2 +- internal/constants/status.go | 14 ++ internal/dns/loop.go | 244 +++++++++++++----------------- internal/dns/state.go | 96 ++++++++++++ internal/models/alias.go | 6 + internal/models/build.go | 2 +- internal/models/openvpn.go | 4 +- internal/models/selection.go | 22 +-- internal/openvpn/loop.go | 144 +++++++++--------- internal/openvpn/state.go | 121 +++++++++++++++ internal/params/dns.go | 4 +- internal/server/dns.go | 76 ++++++++++ internal/server/handler.go | 69 +++------ internal/server/handlerv0.go | 69 +++++++++ internal/server/handlerv1.go | 58 +++++++ internal/server/log.go | 75 +++++++++ internal/server/openvpn.go | 114 +++++++++++--- internal/server/server.go | 3 +- internal/server/updater.go | 78 ++++++++++ internal/server/version.go | 19 --- internal/server/wrappers.go | 32 ++++ internal/settings/openvpn.go | 4 +- internal/settings/openvpn_test.go | 2 +- internal/settings/settings.go | 11 +- internal/settings/updater.go | 59 ++++++++ internal/updater/loop.go | 153 ++++++++++--------- internal/updater/options.go | 32 ---- internal/updater/state.go | 88 +++++++++++ internal/updater/updater.go | 14 +- 30 files changed, 1191 insertions(+), 438 deletions(-) create mode 100644 internal/constants/status.go create mode 100644 internal/dns/state.go create mode 100644 internal/openvpn/state.go create mode 100644 internal/server/dns.go create mode 100644 internal/server/handlerv0.go create mode 100644 internal/server/handlerv1.go create mode 100644 internal/server/log.go create mode 100644 internal/server/updater.go delete mode 100644 internal/server/version.go create mode 100644 internal/server/wrappers.go create mode 100644 internal/settings/updater.go delete mode 100644 internal/updater/options.go create mode 100644 internal/updater/state.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index d0b18240..d38eef49 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -217,15 +217,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) - openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, + openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, uid, gid, allServers, ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel) wg.Add(1) // wait for restartOpenvpn go openvpnLooper.Run(ctx, wg) - updaterOptions := updater.NewOptions("127.0.0.1") - updaterLooper := updater.NewLooper(updaterOptions, allSettings.UpdaterPeriod, - allServers, storage, openvpnLooper.SetAllServers, httpClient, logger) + updaterLooper := updater.NewLooper(allSettings.Updater, + allServers, storage, openvpnLooper.SetServers, httpClient, logger) wg.Add(1) // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker go updaterLooper.Run(ctx, wg) @@ -276,8 +275,9 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go wg.Add(1) go healthcheckServer.Run(ctx, wg) - // Start openvpn for the first time - openvpnLooper.Restart() + // Start openvpn for the first time in a blocking call + // until openvpn is launched + _, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, @@ -401,7 +401,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn tickerWg.Wait() return case <-tunnelReadyCh: // blocks until openvpn is connected - unboundLooper.Restart() + _, _ = unboundLooper.SetStatus(constants.Running) restartTickerCancel() // stop previous restart tickers tickerWg.Wait() restartTickerContext, restartTickerCancel = context.WithCancel(ctx) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index d00a2777..4cf31868 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -83,7 +83,7 @@ func OpenvpnConfig() error { } func Update(args []string) error { - options := updater.Options{CLI: true} + options := settings.Updater{CLI: true} var flushToFile bool flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet.BoolVar(&flushToFile, "file", false, "Write results to /gluetun/servers.json (for end users)") diff --git a/internal/constants/status.go b/internal/constants/status.go new file mode 100644 index 00000000..f045c43f --- /dev/null +++ b/internal/constants/status.go @@ -0,0 +1,14 @@ +package constants + +import ( + "github.com/qdm12/gluetun/internal/models" +) + +const ( + Starting models.LoopStatus = "starting" + Running models.LoopStatus = "running" + Stopping models.LoopStatus = "stopping" + Stopped models.LoopStatus = "stopped" + Crashed models.LoopStatus = "crashed" + Completed models.LoopStatus = "completed" +) diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 3d7c3955..fc29ef30 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -7,6 +7,7 @@ import ( "time" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" @@ -15,80 +16,51 @@ import ( type Looper interface { Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) - Restart() - Start() - Stop() + GetStatus() (status models.LoopStatus) + SetStatus(status models.LoopStatus) (outcome string, err error) GetSettings() (settings settings.DNS) - SetSettings(settings settings.DNS) + SetSettings(settings settings.DNS) (outcome string) } type looper struct { - conf Configurator - settings settings.DNS - settingsMutex sync.RWMutex - logger logging.Logger - streamMerger command.StreamMerger - uid int - gid int - restart chan struct{} - start chan struct{} - stop chan struct{} - updateTicker chan struct{} - timeNow func() time.Time - timeSince func(time.Time) time.Duration + state state + conf Configurator + logger logging.Logger + streamMerger command.StreamMerger + uid int + gid int + loopLock sync.Mutex + start chan struct{} + running chan models.LoopStatus + stop chan struct{} + stopped chan struct{} + updateTicker chan struct{} + timeNow func() time.Time + timeSince func(time.Time) time.Duration } func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper { return &looper{ + state: state{ + status: constants.Stopped, + settings: settings, + }, conf: conf, - settings: settings, logger: logger.WithPrefix("dns over tls: "), uid: uid, gid: gid, streamMerger: streamMerger, - restart: make(chan struct{}), start: make(chan struct{}), + running: make(chan models.LoopStatus), stop: make(chan struct{}), + stopped: make(chan struct{}), updateTicker: make(chan struct{}), timeNow: time.Now, timeSince: time.Since, } } -func (l *looper) Restart() { l.restart <- struct{}{} } -func (l *looper) Start() { l.start <- struct{}{} } -func (l *looper) Stop() { l.stop <- struct{}{} } - -func (l *looper) GetSettings() (settings settings.DNS) { - l.settingsMutex.RLock() - defer l.settingsMutex.RUnlock() - return l.settings -} - -func (l *looper) SetSettings(settings settings.DNS) { - l.settingsMutex.Lock() - defer l.settingsMutex.Unlock() - updatePeriodDiffers := l.settings.UpdatePeriod != settings.UpdatePeriod - l.settings = settings - l.settingsMutex.Unlock() - if updatePeriodDiffers { - l.updateTicker <- struct{}{} - } -} - -func (l *looper) isEnabled() bool { - l.settingsMutex.RLock() - defer l.settingsMutex.RUnlock() - return l.settings.Enabled -} - -func (l *looper) setEnabled(enabled bool) { - l.settingsMutex.Lock() - defer l.settingsMutex.Unlock() - l.settings.Enabled = enabled -} - func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Warn(err) l.logger.Info("attempting restart in 10 seconds") @@ -103,96 +75,42 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } } -func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) { - for { - select { - case <-l.stop: - l.setEnabled(false) - l.logger.Info("not started yet") - case <-l.restart: - if l.isEnabled() { - return - } - signalDNSReady() - l.logger.Info("not restarting because disabled") - case <-l.start: - l.setEnabled(true) - return - case <-ctx.Done(): - return - } - } -} - -func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel context.CancelFunc) { - if l.isEnabled() { - return - } - for { - // wait for a signal to re-enable - select { - case <-l.stop: - l.logger.Info("already disabled") - case <-l.restart: - if !l.isEnabled() { - l.logger.Info("not restarting because disabled") - } else { - return - } - case <-l.start: - l.setEnabled(true) - return - case <-ctx.Done(): - unboundCancel() - return - } - } -} - func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) { defer wg.Done() + const fallback = false - l.useUnencryptedDNS(fallback) - l.waitForFirstStart(ctx, signalDNSReady) - if ctx.Err() != nil { + l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution? + + select { + case <-l.start: + case <-ctx.Done(): return } + defer l.logger.Warn("loop exited") - var unboundCtx context.Context - var unboundCancel context.CancelFunc = func() {} - var waitError chan error - triggeredRestart := false - l.setEnabled(true) for ctx.Err() == nil { - l.waitForSubsequentStart(ctx, unboundCancel) + err := l.updateFiles(ctx) + if err == nil { + break + } + l.state.setStatusWithLock(constants.Crashed) + l.logAndWait(ctx, err) + } + crashed := false + + for ctx.Err() == nil { settings := l.GetSettings() - // Setup - if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil { - l.logAndWait(ctx, err) - continue - } - if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil { - l.logAndWait(ctx, err) - continue - } - if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil { - l.logAndWait(ctx, err) - continue - } - - if triggeredRestart { - triggeredRestart = false - unboundCancel() - <-waitError - close(waitError) - } - unboundCtx, unboundCancel = context.WithCancel(context.Background()) + unboundCtx, unboundCancel := context.WithCancel(context.Background()) stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel) if err != nil { unboundCancel() + if !crashed { + l.running <- constants.Crashed + } + crashed = true const fallback = true l.useUnencryptedDNS(fallback) l.logAndWait(ctx, err) @@ -201,23 +119,37 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun // Started successfully go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound")) + l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound l.logger.Error(err) } + if err := l.conf.WaitForUnbound(); err != nil { + if !crashed { + l.running <- constants.Crashed + crashed = true + } unboundCancel() const fallback = true l.useUnencryptedDNS(fallback) l.logAndWait(ctx, err) continue } - waitError = make(chan error) + + waitError := make(chan error) go func() { err := waitFn() // blocking waitError <- err }() + l.logger.Info("DNS over TLS is ready") + if !crashed { + l.running <- constants.Running + crashed = false + } else { + l.state.setStatusWithLock(constants.Running) + } signalDNSReady() stayHere := true @@ -229,31 +161,28 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun <-waitError close(waitError) return - case <-l.restart: // triggered restart - l.logger.Info("restarting") - // unboundCancel occurs next loop run when the setup is complete - triggeredRestart = true - stayHere = false - case <-l.start: - l.logger.Info("already started") case <-l.stop: l.logger.Info("stopping") + const fallback = false + l.useUnencryptedDNS(fallback) unboundCancel() <-waitError - close(waitError) - l.setEnabled(false) + l.stopped <- struct{}{} + case <-l.start: + l.logger.Info("starting") stayHere = false case err := <-waitError: // unexpected error - close(waitError) unboundCancel() + l.state.setStatusWithLock(constants.Crashed) const fallback = true l.useUnencryptedDNS(fallback) l.logAndWait(ctx, err) stayHere = false } } + close(waitError) + unboundCancel() } - unboundCancel() } func (l *looper) useUnencryptedDNS(fallback bool) { @@ -279,7 +208,11 @@ func (l *looper) useUnencryptedDNS(fallback bool) { data := constants.DNSProviderMapping()[provider] for _, targetIP = range data.IPs { if targetIP.To4() != nil { - l.logger.Info("falling back on plaintext DNS at address %s", targetIP) + if fallback { + l.logger.Info("falling back on plaintext DNS at address %s", targetIP) + } else { + l.logger.Info("using plaintext DNS at address %s", targetIP) + } l.conf.UseDNSInternally(targetIP) if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil { l.logger.Error(err) @@ -314,7 +247,20 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { return case <-timer.C: lastTick = l.timeNow() - l.restart <- struct{}{} + + status := l.GetStatus() + if status == constants.Running { + if err := l.updateFiles(ctx); err != nil { + l.state.setStatusWithLock(constants.Crashed) + l.logger.Error(err) + l.logger.Warn("skipping Unbound restart due to failed files update") + continue + } + } + + _, _ = l.SetStatus(constants.Stopped) + _, _ = l.SetStatus(constants.Running) + settings := l.GetSettings() timer.Reset(settings.UpdatePeriod) case <-l.updateTicker: @@ -337,3 +283,17 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { } } } + +func (l *looper) updateFiles(ctx context.Context) (err error) { + if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil { + return err + } + if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil { + return err + } + settings := l.GetSettings() + if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil { + return err + } + return nil +} diff --git a/internal/dns/state.go b/internal/dns/state.go new file mode 100644 index 00000000..ce9a017c --- /dev/null +++ b/internal/dns/state.go @@ -0,0 +1,96 @@ +package dns + +import ( + "fmt" + "reflect" + "sync" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" +) + +type state struct { + status models.LoopStatus + settings settings.DNS + statusMu sync.RWMutex + settingsMu sync.RWMutex +} + +func (s *state) setStatusWithLock(status models.LoopStatus) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.status = status +} + +func (l *looper) GetStatus() (status models.LoopStatus) { + l.state.statusMu.RLock() + defer l.state.statusMu.RUnlock() + return l.state.status +} + +func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { + l.state.statusMu.Lock() + defer l.state.statusMu.Unlock() + existingStatus := l.state.status + + switch status { + case constants.Running: + switch existingStatus { + case constants.Starting, constants.Running, constants.Stopping, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Starting + l.state.statusMu.Unlock() + l.start <- struct{}{} + newStatus := <-l.running + l.state.statusMu.Lock() + l.state.status = newStatus + return newStatus.String(), nil + case constants.Stopped: + switch existingStatus { + case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Stopping + l.state.statusMu.Unlock() + l.stop <- struct{}{} + <-l.stopped + l.state.statusMu.Lock() + l.state.status = constants.Stopped + return status.String(), nil + default: + return "", fmt.Errorf("status %q can only be %q or %q", + status, constants.Running, constants.Stopped) + } +} + +func (l *looper) GetSettings() (settings settings.DNS) { + l.state.settingsMu.RLock() + defer l.state.settingsMu.RUnlock() + return l.state.settings +} + +func (l *looper) SetSettings(settings settings.DNS) (outcome string) { + l.state.settingsMu.Lock() + settingsUnchanged := reflect.DeepEqual(l.state.settings, settings) + if settingsUnchanged { + l.state.settingsMu.Unlock() + return "settings left unchanged" + } + tempSettings := l.state.settings + tempSettings.UpdatePeriod = settings.UpdatePeriod + onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings) + l.state.settings = settings + if onlyUpdatePeriodChanged { + l.updateTicker <- struct{}{} + return "update period changed" + } + _, _ = l.SetStatus(constants.Stopped) + outcome, _ = l.SetStatus(constants.Running) + return outcome +} diff --git a/internal/models/alias.go b/internal/models/alias.go index b789ab7c..170d3753 100644 --- a/internal/models/alias.go +++ b/internal/models/alias.go @@ -20,8 +20,14 @@ type ( VPNProvider string // NetworkProtocol contains the network protocol to be used to communicate with the VPN servers. NetworkProtocol string + // Loop status such as stopped or running. + LoopStatus string ) +func (ls LoopStatus) String() string { + return string(ls) +} + func marshalJSONString(s string) (data []byte, err error) { return []byte(fmt.Sprintf("%q", s)), nil } diff --git a/internal/models/build.go b/internal/models/build.go index 2034e584..8f877632 100644 --- a/internal/models/build.go +++ b/internal/models/build.go @@ -3,5 +3,5 @@ package models type BuildInformation struct { Version string `json:"version"` Commit string `json:"commit"` - BuildDate string `json:"buildDate"` + BuildDate string `json:"build_date"` } diff --git a/internal/models/openvpn.go b/internal/models/openvpn.go index b09077ba..511cb81f 100644 --- a/internal/models/openvpn.go +++ b/internal/models/openvpn.go @@ -1,6 +1,8 @@ package models -import "net" +import ( + "net" +) type OpenVPNConnection struct { IP net.IP diff --git a/internal/models/selection.go b/internal/models/selection.go index 51c03fb5..c7964af8 100644 --- a/internal/models/selection.go +++ b/internal/models/selection.go @@ -9,15 +9,15 @@ import ( // ProviderSettings contains settings specific to a VPN provider. type ProviderSettings struct { Name VPNProvider `json:"name"` - ServerSelection ServerSelection `json:"serverSelection"` - ExtraConfigOptions ExtraConfigOptions `json:"extraConfig"` - PortForwarding PortForwarding `json:"portForwarding"` + ServerSelection ServerSelection `json:"server_selection"` + ExtraConfigOptions ExtraConfigOptions `json:"extra_config"` + PortForwarding PortForwarding `json:"port_forwarding"` } type ServerSelection struct { // Common - Protocol NetworkProtocol `json:"networkProtocol"` - TargetIP net.IP `json:"targetIP,omitempty"` + Protocol NetworkProtocol `json:"network_protocol"` + TargetIP net.IP `json:"target_ip,omitempty"` // Cyberghost, PIA, Surfshark, Windscribe, Vyprvpn, NordVPN Regions []string `json:"regions"` @@ -34,20 +34,20 @@ type ServerSelection struct { Owned bool `json:"owned"` // Mullvad, Windscribe - CustomPort uint16 `json:"customPort"` + CustomPort uint16 `json:"custom_port"` // NordVPN Numbers []uint16 `json:"numbers"` // PIA - EncryptionPreset string `json:"encryptionPreset"` + EncryptionPreset string `json:"encryption_preset"` } type ExtraConfigOptions struct { - ClientCertificate string `json:"-"` // Cyberghost - ClientKey string `json:"-"` // Cyberghost - EncryptionPreset string `json:"encryptionPreset"` // PIA - OpenVPNIPv6 bool `json:"openvpnIPv6"` // Mullvad + ClientCertificate string `json:"-"` // Cyberghost + ClientKey string `json:"-"` // Cyberghost + EncryptionPreset string `json:"encryption_preset"` // PIA + OpenVPNIPv6 bool `json:"openvpn_ipv6"` // Mullvad } // PortForwarding contains settings for port forwarding. diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index f47649b5..76fd2230 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -20,23 +20,18 @@ import ( type Looper interface { Run(ctx context.Context, wg *sync.WaitGroup) - Restart() - PortForward(vpnGatewayIP net.IP) + GetStatus() (status models.LoopStatus) + SetStatus(status models.LoopStatus) (outcome string, err error) GetSettings() (settings settings.OpenVPN) - SetSettings(settings settings.OpenVPN) - GetPortForwarded() (portForwarded uint16) - SetAllServers(allServers models.AllServers) + SetSettings(settings settings.OpenVPN) (outcome string) + GetServers() (servers models.AllServers) + SetServers(servers models.AllServers) + GetPortForwarded() (port uint16) + PortForward(vpnGatewayIP net.IP) } type looper struct { - // Variable parameters - provider models.VPNProvider - settings settings.OpenVPN - settingsMutex sync.RWMutex - portForwarded uint16 - portForwardedMutex sync.RWMutex - allServers models.AllServers - allServersMutex sync.RWMutex + state state // Fixed parameters uid int gid int @@ -50,22 +45,27 @@ type looper struct { fileManager files.FileManager streamMerger command.StreamMerger cancel context.CancelFunc - // Internal channels - restart chan struct{} + // Internal channels and locks + loopLock sync.Mutex + running chan models.LoopStatus + stop, stopped chan struct{} + start chan struct{} portForwardSignals chan net.IP } -func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, +func NewLooper(settings settings.OpenVPN, uid, gid int, allServers models.AllServers, conf Configurator, fw firewall.Configurator, routing routing.Routing, logger logging.Logger, client *http.Client, fileManager files.FileManager, streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { return &looper{ - provider: provider, - settings: settings, + state: state{ + status: constants.Stopped, + settings: settings, + allServers: allServers, + }, uid: uid, gid: gid, - allServers: allServers, conf: conf, fw: fw, routing: routing, @@ -75,46 +75,29 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, fileManager: fileManager, streamMerger: streamMerger, cancel: cancel, - restart: make(chan struct{}), + start: make(chan struct{}), + running: make(chan models.LoopStatus), + stop: make(chan struct{}), + stopped: make(chan struct{}), portForwardSignals: make(chan net.IP), } } -func (l *looper) Restart() { l.restart <- struct{}{} } func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } -func (l *looper) GetSettings() (settings settings.OpenVPN) { - l.settingsMutex.RLock() - defer l.settingsMutex.RUnlock() - return l.settings -} - -func (l *looper) SetSettings(settings settings.OpenVPN) { - l.settingsMutex.Lock() - defer l.settingsMutex.Unlock() - l.settings = settings -} - -func (l *looper) SetAllServers(allServers models.AllServers) { - l.allServersMutex.Lock() - defer l.allServersMutex.Unlock() - l.allServers = allServers -} - func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() + crashed := false select { - case <-l.restart: + case <-l.start: case <-ctx.Done(): return } defer l.logger.Warn("loop exited") for ctx.Err() == nil { - settings := l.GetSettings() - l.allServersMutex.RLock() - providerConf := provider.New(l.provider, l.allServers, time.Now) - l.allServersMutex.RUnlock() + settings, allServers := l.state.getSettingsAndServers() + providerConf := provider.New(settings.Provider.Name, allServers, time.Now) connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection) if err != nil { l.logger.Error(err) @@ -155,6 +138,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { stream, waitFn, err := l.conf.Start(openvpnCtx) if err != nil { openvpnCancel() + if !crashed { + l.running <- constants.Crashed + crashed = true + } l.logAndWait(ctx, err) continue } @@ -179,23 +166,41 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { err := waitFn() // blocking waitError <- err }() - select { - case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") - openvpnCancel() - <-waitError - close(waitError) - return - case <-l.restart: // triggered restart - l.logger.Info("restarting") - openvpnCancel() - <-waitError - close(waitError) - case err := <-waitError: // unexpected error - openvpnCancel() - close(waitError) - l.logAndWait(ctx, err) + + if !crashed { + l.running <- constants.Running + crashed = false + } else { + l.state.setStatusWithLock(constants.Running) } + + stayHere := true + for stayHere { + select { + case <-ctx.Done(): + l.logger.Warn("context canceled: exiting loop") + openvpnCancel() + <-waitError + close(waitError) + return + case <-l.stop: + l.logger.Info("stopping") + openvpnCancel() + <-waitError + l.stopped <- struct{}{} + case <-l.start: + l.logger.Info("starting") + stayHere = false + case err := <-waitError: // unexpected error + openvpnCancel() + l.state.setStatusWithLock(constants.Crashed) + l.logAndWait(ctx, err) + crashed = true + stayHere = false + } + } + close(waitError) + openvpnCancel() // just for the linter } } @@ -218,24 +223,21 @@ func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup, providerConf provider.Provider, client *http.Client, gateway net.IP) { defer wg.Done() - settings := l.GetSettings() + l.state.portForwardedMu.RLock() + settings := l.state.settings + l.state.portForwardedMu.RUnlock() if !settings.Provider.PortForwarding.Enabled { return } syncState := func(port uint16) (pfFilepath models.Filepath) { - l.portForwardedMutex.Lock() - l.portForwarded = port - l.portForwardedMutex.Unlock() - settings := l.GetSettings() + l.state.portForwardedMu.Lock() + defer l.state.portForwardedMu.Unlock() + l.state.portForwarded = port + l.state.settingsMu.RLock() + defer l.state.settingsMu.RUnlock() return settings.Provider.PortForwarding.Filepath } providerConf.PortForward(ctx, client, l.fileManager, l.pfLogger, gateway, l.fw, syncState) } - -func (l *looper) GetPortForwarded() (portForwarded uint16) { - l.portForwardedMutex.RLock() - defer l.portForwardedMutex.RUnlock() - return l.portForwarded -} diff --git a/internal/openvpn/state.go b/internal/openvpn/state.go new file mode 100644 index 00000000..56f6d0d4 --- /dev/null +++ b/internal/openvpn/state.go @@ -0,0 +1,121 @@ +package openvpn + +import ( + "fmt" + "reflect" + "sync" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" +) + +type state struct { + status models.LoopStatus + settings settings.OpenVPN + allServers models.AllServers + portForwarded uint16 + statusMu sync.RWMutex + settingsMu sync.RWMutex + allServersMu sync.RWMutex + portForwardedMu sync.RWMutex +} + +func (s *state) setStatusWithLock(status models.LoopStatus) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.status = status +} + +func (s *state) getSettingsAndServers() (settings settings.OpenVPN, allServers models.AllServers) { + s.settingsMu.RLock() + s.allServersMu.RLock() + settings = s.settings + allServers = s.allServers + s.settingsMu.RLock() + s.allServersMu.RLock() + return settings, allServers +} + +func (l *looper) GetStatus() (status models.LoopStatus) { + l.state.statusMu.RLock() + defer l.state.statusMu.RUnlock() + return l.state.status +} + +func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { + l.state.statusMu.Lock() + defer l.state.statusMu.Unlock() + existingStatus := l.state.status + + switch status { + case constants.Running: + switch existingStatus { + case constants.Starting, constants.Running, constants.Stopping, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Starting + l.state.statusMu.Unlock() + l.start <- struct{}{} + newStatus := <-l.running + l.state.statusMu.Lock() + l.state.status = newStatus + return newStatus.String(), nil + case constants.Stopped: + switch existingStatus { + case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Stopping + l.state.statusMu.Unlock() + l.stop <- struct{}{} + <-l.stopped + l.state.statusMu.Lock() + l.state.status = constants.Stopped + return status.String(), nil + default: + return "", fmt.Errorf("status %q can only be %q or %q", + status, constants.Running, constants.Stopped) + } +} + +func (l *looper) GetSettings() (settings settings.OpenVPN) { + l.state.settingsMu.RLock() + defer l.state.settingsMu.RUnlock() + return l.state.settings +} + +func (l *looper) SetSettings(settings settings.OpenVPN) (outcome string) { + l.state.settingsMu.Lock() + settingsUnchanged := reflect.DeepEqual(l.state.settings, settings) + if settingsUnchanged { + l.state.settingsMu.Unlock() + return "settings left unchanged" + } + l.state.settings = settings + _, _ = l.SetStatus(constants.Stopped) + outcome, _ = l.SetStatus(constants.Running) + return outcome +} + +func (l *looper) GetServers() (servers models.AllServers) { + l.state.allServersMu.RLock() + defer l.state.allServersMu.RUnlock() + return l.state.allServers +} + +func (l *looper) SetServers(servers models.AllServers) { + l.state.allServersMu.Lock() + defer l.state.allServersMu.Unlock() + l.state.allServers = servers +} + +func (l *looper) GetPortForwarded() (port uint16) { + l.state.portForwardedMu.RLock() + defer l.state.portForwardedMu.RUnlock() + return port +} diff --git a/internal/params/dns.go b/internal/params/dns.go index 250d3692..42a22149 100644 --- a/internal/params/dns.go +++ b/internal/params/dns.go @@ -130,8 +130,8 @@ func (r *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err return privateAddresses, nil } -// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using ipv6 DNS over TLS -// servers from the environment variable DOT_IPV6. +// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using +// ipv6 DNS over TLS from the environment variable DOT_IPV6. func (r *reader) GetDNSOverTLSIPv6() (ipv6 bool, err error) { return r.envParams.GetOnOff("DOT_IPV6", libparams.Default("off")) } diff --git a/internal/server/dns.go b/internal/server/dns.go new file mode 100644 index 00000000..af15f1fb --- /dev/null +++ b/internal/server/dns.go @@ -0,0 +1,76 @@ +//nolint:dupl +package server + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/qdm12/gluetun/internal/dns" + "github.com/qdm12/golibs/logging" +) + +func newDNSHandler(looper dns.Looper, logger logging.Logger) http.Handler { + return &dnsHandler{ + looper: looper, + logger: logger, + } +} + +type dnsHandler struct { + looper dns.Looper + logger logging.Logger +} + +func (h *dnsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.RequestURI = strings.TrimPrefix(r.RequestURI, "/dns") + switch r.RequestURI { + case "/status": //nolint:goconst + switch r.Method { + case http.MethodGet: + h.getStatus(w) + case http.MethodPut: + h.setStatus(w, r) + default: + http.Error(w, "", http.StatusNotFound) + } + default: + http.Error(w, "", http.StatusNotFound) + } +} + +func (h *dnsHandler) getStatus(w http.ResponseWriter) { + status := h.looper.GetStatus() + encoder := json.NewEncoder(w) + data := statusWrapper{Status: string(status)} + if err := encoder.Encode(data); err != nil { + h.logger.Warn(err) + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) { + decoder := json.NewDecoder(r.Body) + var data statusWrapper + if err := decoder.Decode(&data); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + status, err := data.getStatus() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + outcome, err := h.looper.SetStatus(status) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + encoder := json.NewEncoder(w) + if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil { + h.logger.Warn(err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} diff --git a/internal/server/handler.go b/internal/server/handler.go index 7b6ec32f..3f77ee5e 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -1,8 +1,8 @@ package server import ( - "fmt" "net/http" + "strings" "github.com/qdm12/gluetun/internal/dns" "github.com/qdm12/gluetun/internal/models" @@ -17,54 +17,33 @@ func newHandler(logger logging.Logger, logging bool, unboundLooper dns.Looper, updaterLooper updater.Looper, ) http.Handler { - return &handler{ - logger: logger, - logging: logging, - buildInfo: buildInfo, - openvpnLooper: openvpnLooper, - unboundLooper: unboundLooper, - updaterLooper: updaterLooper, - } + handler := &handler{} + + openvpn := newOpenvpnHandler(openvpnLooper, logger) + dns := newDNSHandler(unboundLooper, logger) + updater := newUpdaterHandler(updaterLooper, logger) + + handler.v0 = newHandlerV0(logger, openvpnLooper, unboundLooper, updaterLooper) + handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater) + + handlerWithLog := withLogMiddleware(handler, logger, logging) + handler.setLogEnabled = handlerWithLog.setEnabled + + return handlerWithLog } type handler struct { - logger logging.Logger - logging bool - buildInfo models.BuildInformation - openvpnLooper openvpn.Looper - unboundLooper dns.Looper - updaterLooper updater.Looper + v0 http.Handler + v1 http.Handler + setLogEnabled func(enabled bool) } -func (h *handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { - if h.logging { - h.logger.Info("HTTP %s %s", request.Method, request.RequestURI) - } - switch request.Method { - case http.MethodGet: - switch request.RequestURI { - case "/version": - h.getVersion(responseWriter) - responseWriter.WriteHeader(http.StatusOK) - case "/openvpn/actions/restart": - h.openvpnLooper.Restart() - responseWriter.WriteHeader(http.StatusOK) - case "/unbound/actions/restart": - h.unboundLooper.Restart() - responseWriter.WriteHeader(http.StatusOK) - case "/openvpn/portforwarded": - h.getPortForwarded(responseWriter) - case "/openvpn/settings": - h.getOpenvpnSettings(responseWriter) - case "/updater/restart": - h.updaterLooper.Restart() - responseWriter.WriteHeader(http.StatusOK) - default: - errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI) - http.Error(responseWriter, errString, http.StatusBadRequest) - } - default: - errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI) - http.Error(responseWriter, errString, http.StatusBadRequest) +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.RequestURI = strings.TrimSuffix(r.RequestURI, "/") + if !strings.HasPrefix(r.RequestURI, "/v1/") && r.RequestURI != "/v1" { + h.v0.ServeHTTP(w, r) + return } + r.RequestURI = strings.TrimPrefix(r.RequestURI, "/v1") + h.v1.ServeHTTP(w, r) } diff --git a/internal/server/handlerv0.go b/internal/server/handlerv0.go new file mode 100644 index 00000000..8236cd43 --- /dev/null +++ b/internal/server/handlerv0.go @@ -0,0 +1,69 @@ +package server + +import ( + "net/http" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/dns" + "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/gluetun/internal/updater" + "github.com/qdm12/golibs/logging" +) + +func newHandlerV0(logger logging.Logger, + openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler { + return &handlerV0{ + logger: logger, + openvpn: openvpn, + dns: dns, + updater: updater, + } +} + +type handlerV0 struct { + logger logging.Logger + openvpn openvpn.Looper + dns dns.Looper + updater updater.Looper +} + +func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "unversioned API: only supports GET method", http.StatusNotFound) + return + } + switch r.RequestURI { + case "/version": + http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect) + case "/openvpn/actions/restart": + outcome, _ := h.openvpn.SetStatus(constants.Stopped) + h.logger.Info("openvpn: %s", outcome) + outcome, _ = h.openvpn.SetStatus(constants.Running) + h.logger.Info("openvpn: %s", outcome) + if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil { + h.logger.Warn(err) + } + case "/unbound/actions/restart": + outcome, _ := h.dns.SetStatus(constants.Stopped) + h.logger.Info("dns: %s", outcome) + outcome, _ = h.dns.SetStatus(constants.Running) + h.logger.Info("dns: %s", outcome) + if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil { + h.logger.Warn(err) + } + case "/openvpn/portforwarded": + http.Redirect(w, r, "/v1/openvpn/portforwarded", http.StatusPermanentRedirect) + case "/openvpn/settings": + http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect) + case "/updater/restart": + outcome, _ := h.updater.SetStatus(constants.Stopped) + h.logger.Info("updater: %s", outcome) + outcome, _ = h.updater.SetStatus(constants.Running) + h.logger.Info("updater: %s", outcome) + if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil { + h.logger.Warn(err) + } + default: + http.Error(w, "unversioned API: requested URI not found", http.StatusNotFound) + } +} diff --git a/internal/server/handlerv1.go b/internal/server/handlerv1.go new file mode 100644 index 00000000..948677dd --- /dev/null +++ b/internal/server/handlerv1.go @@ -0,0 +1,58 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/golibs/logging" +) + +func newHandlerV1(logger logging.Logger, buildInfo models.BuildInformation, + openvpn, dns, updater http.Handler) http.Handler { + return &handlerV1{ + logger: logger, + buildInfo: buildInfo, + openvpn: openvpn, + dns: dns, + updater: updater, + } +} + +type handlerV1 struct { + logger logging.Logger + buildInfo models.BuildInformation + openvpn http.Handler + dns http.Handler + updater http.Handler +} + +func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/version" && r.Method == http.MethodGet: + h.getVersion(w) + case strings.HasPrefix(r.RequestURI, "/openvpn"): + h.openvpn.ServeHTTP(w, r) + case strings.HasPrefix(r.RequestURI, "/dns"): + h.dns.ServeHTTP(w, r) + case strings.HasPrefix(r.RequestURI, "/updater"): + h.updater.ServeHTTP(w, r) + default: + errString := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI) + http.Error(w, errString, http.StatusNotFound) + } +} + +func (h *handlerV1) getVersion(w http.ResponseWriter) { + data, err := json.Marshal(h.buildInfo) + if err != nil { + h.logger.Warn(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if _, err := w.Write(data); err != nil { + h.logger.Warn(err) + } +} diff --git a/internal/server/log.go b/internal/server/log.go new file mode 100644 index 00000000..48f2efe8 --- /dev/null +++ b/internal/server/log.go @@ -0,0 +1,75 @@ +package server + +import ( + "net/http" + "sync" + "time" + + "github.com/qdm12/golibs/logging" +) + +func withLogMiddleware(childHandler http.Handler, logger logging.Logger, enabled bool) *logMiddleware { + return &logMiddleware{ + childHandler: childHandler, + logger: logger, + timeNow: time.Now, + enabled: enabled, + } +} + +type logMiddleware struct { + childHandler http.Handler + logger logging.Logger + timeNow func() time.Time + enabled bool + enabledMu sync.RWMutex +} + +func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !m.isEnabled() { + m.childHandler.ServeHTTP(w, r) + return + } + tStart := m.timeNow() + statefulWriter := &statefulResponseWriter{httpWriter: w} + m.childHandler.ServeHTTP(statefulWriter, r) + duration := m.timeNow().Sub(tStart) + m.logger.Info("%d %s %s wrote %dB to %s in %s", + statefulWriter.statusCode, r.Method, r.RequestURI, statefulWriter.length, r.RemoteAddr, duration) +} + +func (m *logMiddleware) setEnabled(enabled bool) { + m.enabledMu.Lock() + defer m.enabledMu.Unlock() + m.enabled = enabled +} + +func (m *logMiddleware) isEnabled() (enabled bool) { + m.enabledMu.RLock() + defer m.enabledMu.RUnlock() + return m.enabled +} + +type statefulResponseWriter struct { + httpWriter http.ResponseWriter + statusCode int + length int +} + +func (w *statefulResponseWriter) Write(b []byte) (n int, err error) { + n, err = w.httpWriter.Write(b) + if w.statusCode == 0 { + w.statusCode = http.StatusOK + } + w.length += n + return n, err +} + +func (w *statefulResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.httpWriter.WriteHeader(statusCode) +} + +func (w *statefulResponseWriter) Header() http.Header { + return w.httpWriter.Header() +} diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index 9b4df9b7..2a3d7361 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -3,34 +3,110 @@ package server import ( "encoding/json" "net/http" + "strings" + + "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/golibs/logging" ) -func (h *handler) getPortForwarded(w http.ResponseWriter) { - port := h.openvpnLooper.GetPortForwarded() - data, err := json.Marshal(struct { - Port uint16 `json:"port"` - }{port}) - if err != nil { - h.logger.Warn(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - if _, err := w.Write(data); err != nil { - h.logger.Warn(err) - w.WriteHeader(http.StatusInternalServerError) +func newOpenvpnHandler(looper openvpn.Looper, logger logging.Logger) http.Handler { + return &openvpnHandler{ + looper: looper, + logger: logger, } } -func (h *handler) getOpenvpnSettings(w http.ResponseWriter) { - settings := h.openvpnLooper.GetSettings() - data, err := json.Marshal(settings) - if err != nil { +type openvpnHandler struct { + looper openvpn.Looper + logger logging.Logger +} + +func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.RequestURI = strings.TrimPrefix(r.RequestURI, "/openvpn") + switch r.RequestURI { + case "/status": + switch r.Method { + case http.MethodGet: + h.getStatus(w) + case http.MethodPut: + h.setStatus(w, r) + default: + http.Error(w, "", http.StatusNotFound) + } + case "/settings": + switch r.Method { + case http.MethodGet: + h.getSettings(w) + default: + http.Error(w, "", http.StatusNotFound) + } + case "/portforwarded": + switch r.Method { + case http.MethodGet: + h.getPortForwarded(w) + default: + http.Error(w, "", http.StatusNotFound) + } + default: + http.Error(w, "", http.StatusNotFound) + } +} + +func (h *openvpnHandler) getStatus(w http.ResponseWriter) { + status := h.looper.GetStatus() + encoder := json.NewEncoder(w) + data := statusWrapper{Status: string(status)} + if err := encoder.Encode(data); err != nil { h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) return } - if _, err := w.Write(data); err != nil { +} + +func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) { //nolint:dupl + decoder := json.NewDecoder(r.Body) + var data statusWrapper + if err := decoder.Decode(&data); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + status, err := data.getStatus() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + outcome, err := h.looper.SetStatus(status) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + encoder := json.NewEncoder(w) + if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil { h.logger.Warn(err) - w.WriteHeader(http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} + +func (h *openvpnHandler) getSettings(w http.ResponseWriter) { + settings := h.looper.GetSettings() + settings.User = "redacted" + settings.Password = "redacted" + encoder := json.NewEncoder(w) + if err := encoder.Encode(settings); err != nil { + h.logger.Warn(err) + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) { + port := h.looper.GetPortForwarded() + encoder := json.NewEncoder(w) + data := portWrapper{Port: port} + if err := encoder.Encode(data); err != nil { + h.logger.Warn(err) + w.WriteHeader(http.StatusInternalServerError) + return } } diff --git a/internal/server/server.go b/internal/server/server.go index a06536c2..f339d91d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -23,7 +23,8 @@ type server struct { handler http.Handler } -func New(address string, logging bool, logger logging.Logger, buildInfo models.BuildInformation, +func New(address string, logging bool, logger logging.Logger, + buildInfo models.BuildInformation, openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server { serverLogger := logger.WithPrefix("http server: ") handler := newHandler(serverLogger, logging, buildInfo, openvpnLooper, unboundLooper, updaterLooper) diff --git a/internal/server/updater.go b/internal/server/updater.go new file mode 100644 index 00000000..746dbcbe --- /dev/null +++ b/internal/server/updater.go @@ -0,0 +1,78 @@ +//nolint:dupl +package server + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/qdm12/gluetun/internal/updater" + "github.com/qdm12/golibs/logging" +) + +func newUpdaterHandler( + looper updater.Looper, + logger logging.Logger) http.Handler { + return &updaterHandler{ + looper: looper, + logger: logger, + } +} + +type updaterHandler struct { + looper updater.Looper + logger logging.Logger +} + +func (h *updaterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.RequestURI = strings.TrimPrefix(r.RequestURI, "/updater") + switch r.RequestURI { + case "/status": + switch r.Method { + case http.MethodGet: + h.getStatus(w) + case http.MethodPut: + h.setStatus(w, r) + default: + http.Error(w, "", http.StatusNotFound) + } + default: + http.Error(w, "", http.StatusNotFound) + } +} + +func (h *updaterHandler) getStatus(w http.ResponseWriter) { + status := h.looper.GetStatus() + encoder := json.NewEncoder(w) + data := statusWrapper{Status: string(status)} + if err := encoder.Encode(data); err != nil { + h.logger.Warn(err) + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (h *updaterHandler) setStatus(w http.ResponseWriter, r *http.Request) { + decoder := json.NewDecoder(r.Body) + var data statusWrapper + if err := decoder.Decode(&data); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + status, err := data.getStatus() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + outcome, err := h.looper.SetStatus(status) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + encoder := json.NewEncoder(w) + if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil { + h.logger.Warn(err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} diff --git a/internal/server/version.go b/internal/server/version.go deleted file mode 100644 index 9f11ef26..00000000 --- a/internal/server/version.go +++ /dev/null @@ -1,19 +0,0 @@ -package server - -import ( - "encoding/json" - "net/http" -) - -func (h *handler) getVersion(w http.ResponseWriter) { - data, err := json.Marshal(h.buildInfo) - if err != nil { - h.logger.Warn(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - if _, err := w.Write(data); err != nil { - h.logger.Warn(err) - w.WriteHeader(http.StatusInternalServerError) - } -} diff --git a/internal/server/wrappers.go b/internal/server/wrappers.go new file mode 100644 index 00000000..2b1f1eaa --- /dev/null +++ b/internal/server/wrappers.go @@ -0,0 +1,32 @@ +package server + +import ( + "fmt" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" +) + +type statusWrapper struct { + Status string `json:"status"` +} + +func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) { + status = models.LoopStatus(sw.Status) + switch status { + case constants.Stopped, constants.Running: + return status, nil + default: + return "", fmt.Errorf( + "invalid status %q: possible values are: %s, %s", + sw.Status, constants.Stopped, constants.Running) + } +} + +type portWrapper struct { + Port uint16 `json:"port"` +} + +type outcomeWrapper struct { + Outcome string `json:"outcome"` +} diff --git a/internal/settings/openvpn.go b/internal/settings/openvpn.go index 5f210aa0..569158e0 100644 --- a/internal/settings/openvpn.go +++ b/internal/settings/openvpn.go @@ -12,9 +12,9 @@ import ( // OpenVPN contains settings to configure the OpenVPN client. type OpenVPN struct { User string `json:"user"` - Password string `json:"-"` + Password string `json:"password"` Verbosity int `json:"verbosity"` - Root bool `json:"runAsRoot"` + Root bool `json:"run_as_root"` Cipher string `json:"cipher"` Auth string `json:"auth"` Provider models.ProviderSettings `json:"provider"` diff --git a/internal/settings/openvpn_test.go b/internal/settings/openvpn_test.go index 7489b1cc..9e1e7444 100644 --- a/internal/settings/openvpn_test.go +++ b/internal/settings/openvpn_test.go @@ -20,7 +20,7 @@ func Test_OpenVPN_JSON(t *testing.T) { data, err := json.Marshal(in) require.NoError(t, err) //nolint:lll - assert.Equal(t, `{"user":"","verbosity":0,"runAsRoot":true,"cipher":"","auth":"","provider":{"name":"name","serverSelection":{"networkProtocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"customPort":0,"numbers":null,"encryptionPreset":""},"extraConfig":{"encryptionPreset":"","openvpnIPv6":false},"portForwarding":{"enabled":false,"filepath":""}}}`, string(data)) + assert.Equal(t, `{"user":"","password":"","verbosity":0,"run_as_root":true,"cipher":"","auth":"","provider":{"name":"name","server_selection":{"network_protocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"custom_port":0,"numbers":null,"encryption_preset":""},"extra_config":{"encryption_preset":"","openvpn_ipv6":false},"port_forwarding":{"enabled":false,"filepath":""}}}`, string(data)) var out OpenVPN err = json.Unmarshal(data, &out) require.NoError(t, err) diff --git a/internal/settings/settings.go b/internal/settings/settings.go index 9042062c..97122428 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -1,7 +1,6 @@ package settings import ( - "fmt" "strings" "time" @@ -24,7 +23,7 @@ type Settings struct { HTTPProxy HTTPProxy ShadowSocks ShadowSocks PublicIPPeriod time.Duration - UpdaterPeriod time.Duration + Updater Updater VersionInformation bool ControlServer ControlServer } @@ -34,10 +33,6 @@ func (s *Settings) String() string { if s.VersionInformation { versionInformation = enabled } - updaterLine := "Updater: disabled" - if s.UpdaterPeriod > 0 { - updaterLine = fmt.Sprintf("Updater period: %s", s.UpdaterPeriod) - } return strings.Join([]string{ "Settings summary below:", s.OpenVPN.String(), @@ -47,9 +42,9 @@ func (s *Settings) String() string { s.HTTPProxy.String(), s.ShadowSocks.String(), s.ControlServer.String(), + s.Updater.String(), "Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0 "Version information: " + versionInformation, - updaterLine, "", // new line at the end }, "\n") } @@ -93,7 +88,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) { if err != nil { return settings, err } - settings.UpdaterPeriod, err = paramsReader.GetUpdaterPeriod() + settings.Updater, err = GetUpdaterSettings(paramsReader) if err != nil { return settings, err } diff --git a/internal/settings/updater.go b/internal/settings/updater.go new file mode 100644 index 00000000..19a967c8 --- /dev/null +++ b/internal/settings/updater.go @@ -0,0 +1,59 @@ +package settings + +import ( + "fmt" + "strings" + "time" + + "github.com/qdm12/gluetun/internal/params" +) + +type Updater struct { + Period time.Duration `json:"period"` + DNSAddress string `json:"dns_address"` + Cyberghost bool `json:"cyberghost"` + Mullvad bool `json:"mullvad"` + Nordvpn bool `json:"nordvpn"` + PIA bool `json:"pia"` + Privado bool `json:"privado"` + Purevpn bool `json:"purevpn"` + Surfshark bool `json:"surfshark"` + Vyprvpn bool `json:"vyprvpn"` + Windscribe bool `json:"windscribe"` + // The two below should be used in CLI mode only + Stdout bool `json:"-"` // in order to update constants file (maintainer side) + CLI bool `json:"-"` +} + +// GetUpdaterSettings obtains the server updater settings using the params functions. +func GetUpdaterSettings(paramsReader params.Reader) (settings Updater, err error) { + settings = Updater{ + Cyberghost: true, + Mullvad: true, + Nordvpn: true, + PIA: true, + Purevpn: true, + Surfshark: true, + Vyprvpn: true, + Windscribe: true, + Stdout: false, + CLI: false, + DNSAddress: "127.0.0.1", + } + settings.Period, err = paramsReader.GetUpdaterPeriod() + if err != nil { + return settings, err + } + return settings, nil +} + +func (s *Updater) String() string { + if s.Period == 0 { + return "Server updater settings: disabled" + } + settingsList := []string{ + "Server updater settings:", + fmt.Sprintf("Period: %s", s.Period), + } + return strings.Join(settingsList, "\n|--") +} diff --git a/internal/updater/loop.go b/internal/updater/loop.go index 91ec74e0..28184132 100644 --- a/internal/updater/loop.go +++ b/internal/updater/loop.go @@ -6,7 +6,9 @@ import ( "sync" "time" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/golibs/logging" ) @@ -14,60 +16,54 @@ import ( type Looper interface { Run(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) - Restart() - Stop() - GetPeriod() (period time.Duration) - SetPeriod(period time.Duration) + GetStatus() (status models.LoopStatus) + SetStatus(status models.LoopStatus) (outcome string, err error) + GetSettings() (settings settings.Updater) + SetSettings(settings settings.Updater) (outcome string) } type looper struct { - period time.Duration - periodMutex sync.RWMutex + state state + // Objects updater Updater storage storage.Storage setAllServers func(allServers models.AllServers) logger logging.Logger - restart chan struct{} - stop chan struct{} - updateTicker chan struct{} - timeNow func() time.Time - timeSince func(time.Time) time.Duration + // Internal channels and locks + loopLock sync.Mutex + start chan struct{} + running chan models.LoopStatus + stop chan struct{} + stopped chan struct{} + updateTicker chan struct{} + // Mock functions + timeNow func() time.Time + timeSince func(time.Time) time.Duration } -func NewLooper(options Options, period time.Duration, currentServers models.AllServers, +func NewLooper(settings settings.Updater, currentServers models.AllServers, storage storage.Storage, setAllServers func(allServers models.AllServers), client *http.Client, logger logging.Logger) Looper { loggerWithPrefix := logger.WithPrefix("updater: ") return &looper{ - period: period, - updater: New(options, client, currentServers, loggerWithPrefix), + state: state{ + status: constants.Stopped, + settings: settings, + }, + updater: New(settings, client, currentServers, loggerWithPrefix), storage: storage, setAllServers: setAllServers, logger: loggerWithPrefix, - restart: make(chan struct{}), + start: make(chan struct{}), + running: make(chan models.LoopStatus), stop: make(chan struct{}), + stopped: make(chan struct{}), updateTicker: make(chan struct{}), timeNow: time.Now, timeSince: time.Since, } } -func (l *looper) Restart() { l.restart <- struct{}{} } -func (l *looper) Stop() { l.stop <- struct{}{} } - -func (l *looper) GetPeriod() (period time.Duration) { - l.periodMutex.RLock() - defer l.periodMutex.RUnlock() - return l.period -} - -func (l *looper) SetPeriod(period time.Duration) { - l.periodMutex.Lock() - l.period = period - l.periodMutex.Unlock() - l.updateTicker <- struct{}{} -} - func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Error(err) const waitTime = 5 * time.Minute @@ -84,52 +80,71 @@ func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() + crashed := false select { - case <-l.restart: - l.logger.Info("starting...") + case <-l.start: case <-ctx.Done(): return } defer l.logger.Warn("loop exited") - enabled := true - for ctx.Err() == nil { - for !enabled { - // wait for a signal to re-enable + updateCtx, updateCancel := context.WithCancel(ctx) + defer updateCancel() + serversCh := make(chan models.AllServers) + errorCh := make(chan error) + go func() { + servers, err := l.updater.UpdateServers(updateCtx) + if err != nil { + errorCh <- err + return + } + serversCh <- servers + }() + + if !crashed { + l.running <- constants.Running + crashed = false + } else { + l.state.setStatusWithLock(constants.Running) + } + + stayHere := true + for stayHere { select { - case <-l.stop: - l.logger.Info("already disabled") - case <-l.restart: - enabled = true case <-ctx.Done(): + l.logger.Warn("context canceled: exiting loop") + updateCancel() + close(errorCh) return + case <-l.start: + l.logger.Info("starting") + updateCancel() + stayHere = false + case <-l.stop: + l.logger.Info("stopping") + updateCancel() + <-errorCh + l.stopped <- struct{}{} + case servers := <-serversCh: + updateCancel() + close(serversCh) + l.setAllServers(servers) + if err := l.storage.FlushToFile(servers); err != nil { + l.logger.Error(err) + } + l.state.setStatusWithLock(constants.Completed) + l.logger.Info("Updated servers information") + case err := <-errorCh: + updateCancel() + close(serversCh) + l.state.setStatusWithLock(constants.Crashed) + l.logAndWait(ctx, err) + crashed = true + stayHere = false } } - - // Enabled and has a period set - - servers, err := l.updater.UpdateServers(ctx) - if err != nil { - if ctx.Err() != nil { - return - } - l.logAndWait(ctx, err) - continue - } - l.setAllServers(servers) - if err := l.storage.FlushToFile(servers); err != nil { - l.logger.Error(err) - } - l.logger.Info("Updated servers information") - - select { - case <-l.restart: // triggered restart - case <-l.stop: - enabled = false - case <-ctx.Done(): - return - } + close(errorCh) } } @@ -138,7 +153,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { timer := time.NewTimer(time.Hour) timer.Stop() timerIsStopped := true - if period := l.GetPeriod(); period > 0 { + if period := l.GetSettings().Period; period > 0 { timerIsStopped = false timer.Reset(period) } @@ -152,14 +167,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { return case <-timer.C: lastTick = l.timeNow() - l.restart <- struct{}{} - timer.Reset(l.GetPeriod()) + l.start <- struct{}{} + timer.Reset(l.GetSettings().Period) case <-l.updateTicker: if !timerIsStopped && !timer.Stop() { <-timer.C } timerIsStopped = true - period := l.GetPeriod() + period := l.GetSettings().Period if period == 0 { continue } diff --git a/internal/updater/options.go b/internal/updater/options.go deleted file mode 100644 index b26fdc88..00000000 --- a/internal/updater/options.go +++ /dev/null @@ -1,32 +0,0 @@ -package updater - -type Options struct { - Cyberghost bool - Mullvad bool - Nordvpn bool - PIA bool - Privado bool - Purevpn bool - Surfshark bool - Vyprvpn bool - Windscribe bool - Stdout bool // in order to update constants file (maintainer side) - CLI bool - DNSAddress string -} - -func NewOptions(dnsAddress string) Options { - return Options{ - Cyberghost: true, - Mullvad: true, - Nordvpn: true, - PIA: true, - Purevpn: true, - Surfshark: true, - Vyprvpn: true, - Windscribe: true, - Stdout: false, - CLI: false, - DNSAddress: dnsAddress, - } -} diff --git a/internal/updater/state.go b/internal/updater/state.go new file mode 100644 index 00000000..e55d6f44 --- /dev/null +++ b/internal/updater/state.go @@ -0,0 +1,88 @@ +package updater + +import ( + "fmt" + "reflect" + "sync" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" +) + +type state struct { + status models.LoopStatus + settings settings.Updater + statusMu sync.RWMutex + periodMu sync.RWMutex +} + +func (s *state) setStatusWithLock(status models.LoopStatus) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.status = status +} + +func (l *looper) GetStatus() (status models.LoopStatus) { + l.state.statusMu.RLock() + defer l.state.statusMu.RUnlock() + return l.state.status +} + +func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { + l.state.statusMu.Lock() + defer l.state.statusMu.Unlock() + existingStatus := l.state.status + + switch status { + case constants.Running: + switch existingStatus { + case constants.Starting, constants.Running, constants.Stopping, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Starting + l.state.statusMu.Unlock() + l.start <- struct{}{} + newStatus := <-l.running + l.state.statusMu.Lock() + l.state.status = newStatus + return newStatus.String(), nil + case constants.Stopped: + switch existingStatus { + case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Stopping + l.state.statusMu.Unlock() + l.stop <- struct{}{} + <-l.stopped + l.state.statusMu.Lock() + l.state.status = status + return status.String(), nil + default: + return "", fmt.Errorf("status %q can only be %q or %q", + status, constants.Running, constants.Stopped) + } +} + +func (l *looper) GetSettings() (settings settings.Updater) { + l.state.periodMu.RLock() + defer l.state.periodMu.RUnlock() + return l.state.settings +} + +func (l *looper) SetSettings(settings settings.Updater) (outcome string) { + l.state.periodMu.Lock() + defer l.state.periodMu.Unlock() + settingsUnchanged := reflect.DeepEqual(settings, l.state.settings) + if settingsUnchanged { + return "settings left unchanged" + } + l.state.settings = settings + l.updateTicker <- struct{}{} + return "settings updated" +} diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 9f83f4ae..05336944 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -7,6 +7,7 @@ import ( "time" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" ) @@ -17,7 +18,7 @@ type Updater interface { type updater struct { // configuration - options Options + options settings.Updater // state servers models.AllServers @@ -30,11 +31,12 @@ type updater struct { client network.Client } -func New(options Options, httpClient *http.Client, currentServers models.AllServers, logger logging.Logger) Updater { - if len(options.DNSAddress) == 0 { - options.DNSAddress = "1.1.1.1" +func New(settings settings.Updater, httpClient *http.Client, + currentServers models.AllServers, logger logging.Logger) Updater { + if len(settings.DNSAddress) == 0 { + settings.DNSAddress = "1.1.1.1" } - resolver := newResolver(options.DNSAddress) + resolver := newResolver(settings.DNSAddress) const clientTimeout = 10 * time.Second return &updater{ logger: logger, @@ -42,7 +44,7 @@ func New(options Options, httpClient *http.Client, currentServers models.AllServ println: func(s string) { fmt.Println(s) }, lookupIP: newLookupIP(resolver), client: network.NewClient(clientTimeout), - options: options, + options: settings, servers: currentServers, } }