diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 5b43ae04..0131c7ee 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -377,9 +377,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, portForwardLogger := logger.New(log.SetComponent("port forwarding")) portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding, httpClient, firewallConf, portForwardLogger, puid, pgid) - portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler( - "port forwarding", goroutine.OptionTimeout(time.Second)) - go portForwardLooper.Run(portForwardCtx, portForwardDone) + portForwardRunError, _ := portForwardLooper.Start(context.Background()) unboundLogger := logger.New(log.SetComponent("dns")) unboundLooper := dns.NewLoop(dnsConf, allSettings.DNS, httpClient, @@ -481,13 +479,21 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, order.OptionOnSuccess(defaultShutdownOnSuccess), order.OptionOnFailure(defaultShutdownOnFailure)) orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler, - vpnHandler, portForwardHandler, otherGroupHandler) + vpnHandler, otherGroupHandler) // Start VPN for the first time in a blocking call // until the VPN is launched _, _ = vpnLooper.ApplyStatus(ctx, constants.Running) // TODO option to disable with variable - <-ctx.Done() + select { + case <-ctx.Done(): + err = portForwardLooper.Stop() + if err != nil { + logger.Error("stopping port forward loop: " + err.Error()) + } + case err := <-portForwardRunError: + logger.Errorf("port forwarding loop crashed: %s", err) + } return orderHandler.Shutdown(context.Background()) } diff --git a/internal/configuration/settings/portforward.go b/internal/configuration/settings/portforward.go index a6b60c9c..31b512c6 100644 --- a/internal/configuration/settings/portforward.go +++ b/internal/configuration/settings/portforward.go @@ -30,7 +30,7 @@ type PortForwarding struct { Filepath *string `json:"status_file_path"` } -func (p PortForwarding) validate(vpnProvider string) (err error) { +func (p PortForwarding) Validate(vpnProvider string) (err error) { if !*p.Enabled { return nil } @@ -59,7 +59,7 @@ func (p PortForwarding) validate(vpnProvider string) (err error) { return nil } -func (p *PortForwarding) copy() (copied PortForwarding) { +func (p *PortForwarding) Copy() (copied PortForwarding) { return PortForwarding{ Enabled: gosettings.CopyPointer(p.Enabled), Provider: gosettings.CopyPointer(p.Provider), @@ -73,7 +73,7 @@ func (p *PortForwarding) mergeWith(other PortForwarding) { p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath) } -func (p *PortForwarding) overrideWith(other PortForwarding) { +func (p *PortForwarding) OverrideWith(other PortForwarding) { p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled) p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider) p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath) diff --git a/internal/configuration/settings/provider.go b/internal/configuration/settings/provider.go index 10b24d92..bbe4d38b 100644 --- a/internal/configuration/settings/provider.go +++ b/internal/configuration/settings/provider.go @@ -49,7 +49,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) { return fmt.Errorf("server selection: %w", err) } - err = p.PortForwarding.validate(*p.Name) + err = p.PortForwarding.Validate(*p.Name) if err != nil { return fmt.Errorf("port forwarding: %w", err) } @@ -61,7 +61,7 @@ func (p *Provider) copy() (copied Provider) { return Provider{ Name: gosettings.CopyPointer(p.Name), ServerSelection: p.ServerSelection.copy(), - PortForwarding: p.PortForwarding.copy(), + PortForwarding: p.PortForwarding.Copy(), } } @@ -74,7 +74,7 @@ func (p *Provider) mergeWith(other Provider) { func (p *Provider) overrideWith(other Provider) { p.Name = gosettings.OverrideWithPointer(p.Name, other.Name) p.ServerSelection.overrideWith(other.ServerSelection) - p.PortForwarding.overrideWith(other.PortForwarding) + p.PortForwarding.OverrideWith(other.PortForwarding) } func (p *Provider) setDefaults() { diff --git a/internal/portforward/firewall.go b/internal/portforward/firewall.go deleted file mode 100644 index 42e7b54e..00000000 --- a/internal/portforward/firewall.go +++ /dev/null @@ -1,32 +0,0 @@ -package portforward - -import "context" - -// firewallBlockPort obtains the state port thread safely and blocks -// it in the firewall if it is not the zero value (0). -func (l *Loop) firewallBlockPort(ctx context.Context) { - port := l.state.GetPortForwarded() - if port == 0 { - return - } - - err := l.portAllower.RemoveAllowedPort(ctx, port) - if err != nil { - l.logger.Error("cannot block previous port in firewall: " + err.Error()) - } -} - -// firewallAllowPort obtains the state port thread safely and allows -// it in the firewall if it is not the zero value (0). -func (l *Loop) firewallAllowPort(ctx context.Context) { - port := l.state.GetPortForwarded() - if port == 0 { - return - } - - startData := l.state.GetStartData() - err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface) - if err != nil { - l.logger.Error("cannot allow port: " + err.Error()) - } -} diff --git a/internal/portforward/fs.go b/internal/portforward/fs.go deleted file mode 100644 index 638a9127..00000000 --- a/internal/portforward/fs.go +++ /dev/null @@ -1,37 +0,0 @@ -package portforward - -import ( - "fmt" - "os" -) - -func (l *Loop) removePortForwardedFile() { - filepath := *l.state.GetSettings().Filepath - l.logger.Info("removing port file " + filepath) - if err := os.Remove(filepath); err != nil { - l.logger.Error(err.Error()) - } -} - -func (l *Loop) writePortForwardedFile(port uint16) { - filepath := *l.state.GetSettings().Filepath - l.logger.Info("writing port file " + filepath) - if err := writePortForwardedToFile(filepath, port, l.puid, l.pgid); err != nil { - l.logger.Error("writing port forwarded to file: " + err.Error()) - } -} - -func writePortForwardedToFile(filepath string, port uint16, uid, gid int) (err error) { - const perms = os.FileMode(0644) - err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms) - if err != nil { - return fmt.Errorf("writing file: %w", err) - } - - err = os.Chown(filepath, uid, gid) - if err != nil { - return fmt.Errorf("chowning file: %w", err) - } - - return nil -} diff --git a/internal/portforward/get.go b/internal/portforward/get.go deleted file mode 100644 index 4078ec56..00000000 --- a/internal/portforward/get.go +++ /dev/null @@ -1,5 +0,0 @@ -package portforward - -func (l *Loop) GetPortForwarded() (port uint16) { - return l.state.GetPortForwarded() -} diff --git a/internal/portforward/helpers.go b/internal/portforward/helpers.go deleted file mode 100644 index 5dd28654..00000000 --- a/internal/portforward/helpers.go +++ /dev/null @@ -1,22 +0,0 @@ -package portforward - -import ( - "context" - "time" -) - -func (l *Loop) logAndWait(ctx context.Context, err error) { - if err != nil { - l.logger.Error(err.Error()) - } - l.logger.Info("retrying in " + l.backoffTime.String()) - timer := time.NewTimer(l.backoffTime) - l.backoffTime *= 2 - select { - case <-timer.C: - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - } -} diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index a2e2ad5f..7811c733 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -1,10 +1,20 @@ package portforward -import ( - "context" -) +import "context" + +type Service interface { + Start(ctx context.Context) (runError <-chan error, err error) + Stop() (err error) + GetPortForwarded() (port uint16) +} type PortAllower interface { SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error) } + +type Logger interface { + Info(s string) + Warn(s string) + Error(s string) +} diff --git a/internal/portforward/logger.go b/internal/portforward/logger.go deleted file mode 100644 index f03ab200..00000000 --- a/internal/portforward/logger.go +++ /dev/null @@ -1,7 +0,0 @@ -package portforward - -type Logger interface { - Info(s string) - Warn(s string) - Error(s string) -} diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index 8fd3dc63..f8ff4cc2 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -1,64 +1,139 @@ package portforward import ( + "context" + "fmt" "net/http" "sync" - "time" "github.com/qdm12/gluetun/internal/configuration/settings" - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/gluetun/internal/loopstate" - "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/portforward/state" + "github.com/qdm12/gluetun/internal/portforward/service" ) type Loop struct { - statusManager *loopstate.State - state *state.State - // Fixed parameters - puid int - pgid int - // Objects + // State + settings service.Settings + settingsMutex sync.RWMutex + service Service + // Fixed injected objets client *http.Client portAllower PortAllower logger Logger + // Fixed parameters + uid, gid int // Internal channels and locks - start chan struct{} - running chan models.LoopStatus - stop chan struct{} - stopped chan struct{} - startMu sync.Mutex - backoffTime time.Duration - userTrigger bool + // runCtx is used to detect when the loop has exited + // when performing an update + runCtx context.Context //nolint:containedctx + runCancel context.CancelFunc + updatedSignal chan<- struct{} + runDone <-chan struct{} } -const defaultBackoffTime = 5 * time.Second - func NewLoop(settings settings.PortForwarding, client *http.Client, portAllower PortAllower, - logger Logger, puid, pgid int) *Loop { - start := make(chan struct{}) - running := make(chan models.LoopStatus) - stop := make(chan struct{}) - stopped := make(chan struct{}) - - statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) - state := state.New(statusManager, settings) - + logger Logger, uid, gid int) *Loop { return &Loop{ - statusManager: statusManager, - state: state, - puid: puid, - pgid: pgid, - // Objects + settings: service.Settings{ + UserSettings: settings, + }, client: client, portAllower: portAllower, logger: logger, - start: start, - running: running, - stop: stop, - stopped: stopped, - userTrigger: true, - backoffTime: defaultBackoffTime, + uid: uid, + gid: gid, } } + +func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) { + l.runCtx, l.runCancel = context.WithCancel(context.Background()) + runDone := make(chan struct{}) + l.runDone = runDone + + updatedSignal := make(chan struct{}) + l.updatedSignal = updatedSignal + runErrorCh := make(chan error) + + go l.run(l.runCtx, runDone, runErrorCh, updatedSignal) + + return runErrorCh, nil +} + +func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, + runErrorCh chan<- error, updatedSignal <-chan struct{}) { + defer close(runDone) + + var serviceRunError <-chan error + for { + select { + case <-runCtx.Done(): + // Stop call takes care of stopping the service + return + case <-updatedSignal: // first and subsequent start trigger + case err := <-serviceRunError: + runErrorCh <- err + return + } + + firstRun := l.service == nil + if !firstRun { + err := l.service.Stop() + if err != nil { + runErrorCh <- fmt.Errorf("stopping previous service: %w", err) + return + } + } + + l.settingsMutex.RLock() + l.service = service.New(l.settings, l.client, + l.portAllower, l.logger, l.uid, l.gid) + l.settingsMutex.RUnlock() + + var err error + serviceRunError, err = l.service.Start(runCtx) + if err != nil { + if runCtx.Err() == nil { // crashed but NOT stopped + runErrorCh <- fmt.Errorf("starting new service: %w", err) + } + return + } + } +} + +func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) { + l.settingsMutex.Lock() + l.settings, err = l.settings.UpdateWith(partialUpdate) + l.settingsMutex.Unlock() + if err != nil { + return err + } + + select { + case l.updatedSignal <- struct{}{}: + // Settings are validated and if the service fails to start + // or crashes at runtime, the loop will stop and signal its + // parent goroutine. Settings validation should be the only + // error feedback for the caller of `Update`. + return nil + case <-l.runCtx.Done(): + // loop has been stopped, no update can be done + return l.runCtx.Err() + } +} + +func (l *Loop) Stop() (err error) { + l.runCancel() + <-l.runDone + + if l.service != nil { + return l.service.Stop() + } + return nil +} + +func (l *Loop) GetPortForwarded() (port uint16) { + if l.service == nil { + return 0 + } + return l.service.GetPortForwarded() +} diff --git a/internal/portforward/run.go b/internal/portforward/run.go deleted file mode 100644 index bf0ab740..00000000 --- a/internal/portforward/run.go +++ /dev/null @@ -1,98 +0,0 @@ -package portforward - -import ( - "context" - "strconv" - - "github.com/qdm12/gluetun/internal/constants" -) - -func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { - defer close(done) - - select { - case <-l.start: // l.state.SetStartData called beforehand - case <-ctx.Done(): - return - } - - for ctx.Err() == nil { - pfCtx, pfCancel := context.WithCancel(ctx) - - portCh := make(chan uint16) - errorCh := make(chan error) - - startData := l.state.GetStartData() - - go func(ctx context.Context, startData StartData) { - port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger, - startData.Gateway, startData.ServerName) - if err != nil { - errorCh <- err - return - } - portCh <- port - - // Infinite loop - err = startData.PortForwarder.KeepPortForward(ctx, port, - startData.Gateway, startData.ServerName, l.logger) - errorCh <- err - }(pfCtx, startData) - - if l.userTrigger { - l.userTrigger = false - l.running <- constants.Running - } else { // crash - l.backoffTime = defaultBackoffTime - l.statusManager.SetStatus(constants.Running) - } - - stayHere := true - stopped := false - for stayHere { - select { - case <-ctx.Done(): - pfCancel() - if stopped { - return - } - <-errorCh - close(errorCh) - close(portCh) - l.removePortForwardedFile() - l.firewallBlockPort(ctx) - l.state.SetPortForwarded(0) - return - case <-l.start: - l.userTrigger = true - l.logger.Info("starting") - pfCancel() - stayHere = false - case <-l.stop: - l.userTrigger = true - l.logger.Info("stopping") - pfCancel() - <-errorCh - l.removePortForwardedFile() - l.firewallBlockPort(ctx) - l.state.SetPortForwarded(0) - l.stopped <- struct{}{} - stopped = true - case port := <-portCh: - l.logger.Info("port forwarded is " + strconv.Itoa(int(port))) - l.firewallBlockPort(ctx) - l.state.SetPortForwarded(port) - l.firewallAllowPort(ctx) - l.writePortForwardedFile(port) - case err := <-errorCh: - pfCancel() - close(errorCh) - close(portCh) - l.statusManager.SetStatus(constants.Crashed) - l.logAndWait(ctx, err) - stayHere = false - } - } - pfCancel() // for linting - } -} diff --git a/internal/portforward/service/fs.go b/internal/portforward/service/fs.go new file mode 100644 index 00000000..bc10098e --- /dev/null +++ b/internal/portforward/service/fs.go @@ -0,0 +1,23 @@ +package service + +import ( + "fmt" + "os" +) + +func (s *Service) writePortForwardedFile(port uint16) (err error) { + filepath := *s.settings.UserSettings.Filepath + s.logger.Info("writing port file " + filepath) + const perms = os.FileMode(0644) + err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms) + if err != nil { + return fmt.Errorf("writing file: %w", err) + } + + err = os.Chown(filepath, s.puid, s.pgid) + if err != nil { + return fmt.Errorf("chowning file: %w", err) + } + + return nil +} diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go new file mode 100644 index 00000000..0265ea7f --- /dev/null +++ b/internal/portforward/service/interfaces.go @@ -0,0 +1,16 @@ +package service + +import ( + "context" +) + +type PortAllower interface { + SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) + RemoveAllowedPort(ctx context.Context, port uint16) (err error) +} + +type Logger interface { + Info(s string) + Warn(s string) + Error(s string) +} diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go new file mode 100644 index 00000000..16cd07db --- /dev/null +++ b/internal/portforward/service/service.go @@ -0,0 +1,45 @@ +package service + +import ( + "context" + "net/http" + "sync" +) + +type Service struct { + // State + portMutex sync.RWMutex + port uint16 + // Fixed parameters + settings Settings + puid int + pgid int + // Fixed injected objets + client *http.Client + portAllower PortAllower + logger Logger + // Internal channels and locks + startStopMutex sync.Mutex + keepPortCancel context.CancelFunc + keepPortDoneCh <-chan struct{} +} + +func New(settings Settings, client *http.Client, + portAllower PortAllower, logger Logger, puid, pgid int) *Service { + return &Service{ + // Fixed parameters + settings: settings, + puid: puid, + pgid: pgid, + // Fixed injected objets + client: client, + portAllower: portAllower, + logger: logger, + } +} + +func (s *Service) GetPortForwarded() (port uint16) { + s.portMutex.RLock() + defer s.portMutex.RUnlock() + return s.port +} diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go new file mode 100644 index 00000000..8f8ce108 --- /dev/null +++ b/internal/portforward/service/settings.go @@ -0,0 +1,79 @@ +package service + +import ( + "errors" + "fmt" + "net/netip" + + "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants/providers" + "github.com/qdm12/gluetun/internal/provider" + "github.com/qdm12/gosettings" +) + +type Settings struct { + UserSettings settings.PortForwarding + PortForwarder provider.PortForwarder + Gateway netip.Addr // needed for PIA and ProtonVPN + ServerName string // needed for PIA + Interface string // needed for PIA and ProtonVPN, tun0 for example + VPNProvider string // used to validate new settings +} + +// UpdateWith deep copies the receiving settings, overrides the copy with +// fields set in the partialUpdate argument, validates the new settings +// and returns them if they are valid, or returns an error otherwise. +// In all cases, the receiving settings are unmodified. +func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) { + updatedSettings = s.copy() + updatedSettings.overrideWith(partialUpdate) + err = updatedSettings.validate() + if err != nil { + return updatedSettings, fmt.Errorf("validating new settings: %w", err) + } + return updatedSettings, nil +} + +func (s Settings) copy() (copied Settings) { + copied.UserSettings = s.UserSettings.Copy() + copied.PortForwarder = s.PortForwarder + copied.Gateway = s.Gateway + copied.ServerName = s.ServerName + copied.Interface = s.Interface + copied.VPNProvider = s.VPNProvider + return copied +} + +func (s *Settings) overrideWith(update Settings) { + s.UserSettings.OverrideWith(update.UserSettings) + s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder) + s.Gateway = gosettings.OverrideWithValidator(s.Gateway, update.Gateway) + s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName) + s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface) + s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider) +} + +var ( + ErrVPNProviderNotSet = errors.New("VPN provider not set") + ErrServerNameNotSet = errors.New("server name not set") + ErrPortForwarderNotSet = errors.New("port forwarder not set") + ErrGatewayNotSet = errors.New("gateway not set") + ErrInterfaceNotSet = errors.New("interface not set") +) + +func (s *Settings) validate() (err error) { + switch { + case s.VPNProvider == "": + return fmt.Errorf("%w", ErrVPNProviderNotSet) + case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "": + return fmt.Errorf("%w", ErrServerNameNotSet) + case s.PortForwarder == nil: + return fmt.Errorf("%w", ErrPortForwarderNotSet) + case !s.Gateway.IsValid(): + return fmt.Errorf("%w", ErrGatewayNotSet) + case s.Interface == "": + return fmt.Errorf("%w", ErrInterfaceNotSet) + } + + return s.UserSettings.Validate(s.VPNProvider) +} diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go new file mode 100644 index 00000000..c50f5c53 --- /dev/null +++ b/internal/portforward/service/start.go @@ -0,0 +1,60 @@ +package service + +import ( + "context" + "fmt" +) + +func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) { + s.startStopMutex.Lock() + defer s.startStopMutex.Unlock() + + if !*s.settings.UserSettings.Enabled { + return nil, nil //nolint:nilnil + } + + s.logger.Info("starting") + port, err := s.settings.PortForwarder.PortForward(ctx, s.client, s.logger, + s.settings.Gateway, s.settings.ServerName) + if err != nil { + return nil, fmt.Errorf("port forwarding for the first time: %w", err) + } + + s.logger.Info("port forwarded is " + fmt.Sprint(int(port))) + + err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) + if err != nil { + return nil, fmt.Errorf("allowing port in firewall: %w", err) + } + + err = s.writePortForwardedFile(port) + if err != nil { + _ = s.cleanup() + return nil, fmt.Errorf("writing port file: %w", err) + } + + s.portMutex.Lock() + s.port = port + s.portMutex.Unlock() + + keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) + s.keepPortCancel = keepPortCancel + runErrorCh := make(chan error) + keepPortDoneCh := make(chan struct{}) + s.keepPortDoneCh = keepPortDoneCh + + go func(ctx context.Context, settings Settings, port uint16, + runError chan<- error, doneCh chan<- struct{}) { + defer close(doneCh) + err = settings.PortForwarder.KeepPortForward(ctx, port, + settings.Gateway, settings.ServerName, s.logger) + crashed := ctx.Err() == nil + if !crashed { // stopped by Stop call + return + } + _ = s.cleanup() + runError <- err + }(keepPortCtx, s.settings, port, runErrorCh, keepPortDoneCh) + + return runErrorCh, nil +} diff --git a/internal/portforward/service/stop.go b/internal/portforward/service/stop.go new file mode 100644 index 00000000..bd6e1871 --- /dev/null +++ b/internal/portforward/service/stop.go @@ -0,0 +1,47 @@ +package service + +import ( + "context" + "fmt" + "os" +) + +func (s *Service) Stop() (err error) { + s.startStopMutex.Lock() + defer s.startStopMutex.Unlock() + + s.portMutex.RLock() + serviceNotRunning := s.port == 0 + s.portMutex.RUnlock() + if serviceNotRunning { + return nil + } + + s.logger.Info("stopping") + + s.keepPortCancel() + <-s.keepPortDoneCh + + return s.cleanup() +} + +func (s *Service) cleanup() (err error) { + s.portMutex.Lock() + defer s.portMutex.Unlock() + + err = s.portAllower.RemoveAllowedPort(context.Background(), s.port) + if err != nil { + return fmt.Errorf("blocking previous port in firewall: %w", err) + } + + s.port = 0 + + filepath := *s.settings.UserSettings.Filepath + s.logger.Info("removing port file " + filepath) + err = os.Remove(filepath) + if err != nil { + return fmt.Errorf("removing port file: %w", err) + } + + return nil +} diff --git a/internal/portforward/settings.go b/internal/portforward/settings.go deleted file mode 100644 index a72d8a17..00000000 --- a/internal/portforward/settings.go +++ /dev/null @@ -1,16 +0,0 @@ -package portforward - -import ( - "context" - - "github.com/qdm12/gluetun/internal/configuration/settings" -) - -func (l *Loop) GetSettings() (settings settings.PortForwarding) { - return l.state.GetSettings() -} - -func (l *Loop) SetSettings(ctx context.Context, settings settings.PortForwarding) ( - outcome string) { - return l.state.SetSettings(ctx, settings) -} diff --git a/internal/portforward/state/portforwarded.go b/internal/portforward/state/portforwarded.go deleted file mode 100644 index f11b0aae..00000000 --- a/internal/portforward/state/portforwarded.go +++ /dev/null @@ -1,17 +0,0 @@ -package state - -// GetPortForwarded is used by the control HTTP server -// to obtain the port currently forwarded. -func (s *State) GetPortForwarded() (port uint16) { - s.portForwardedMu.RLock() - defer s.portForwardedMu.RUnlock() - return s.portForwarded -} - -// SetPortForwarded is only used from within the OpenVPN loop -// to set the port forwarded. -func (s *State) SetPortForwarded(port uint16) { - s.portForwardedMu.Lock() - defer s.portForwardedMu.Unlock() - s.portForwarded = port -} diff --git a/internal/portforward/state/settings.go b/internal/portforward/state/settings.go deleted file mode 100644 index fe6d629e..00000000 --- a/internal/portforward/state/settings.go +++ /dev/null @@ -1,49 +0,0 @@ -package state - -import ( - "context" - "os" - "reflect" - - "github.com/qdm12/gluetun/internal/configuration/settings" - "github.com/qdm12/gluetun/internal/constants" -) - -func (s *State) GetSettings() (settings settings.PortForwarding) { - s.settingsMu.RLock() - defer s.settingsMu.RUnlock() - return s.settings -} - -func (s *State) SetSettings(ctx context.Context, settings settings.PortForwarding) ( - outcome string) { - s.settingsMu.Lock() - - settingsUnchanged := reflect.DeepEqual(s.settings, settings) - if settingsUnchanged { - s.settingsMu.Unlock() - return "settings left unchanged" - } - - if s.settings.Filepath != settings.Filepath { - _ = os.Rename(*s.settings.Filepath, *settings.Filepath) - } - - newEnabled := *settings.Enabled - previousEnabled := *s.settings.Enabled - - s.settings = settings - s.settingsMu.Unlock() - - switch { - case !newEnabled && !previousEnabled: - case newEnabled && previousEnabled: - // no need to restart for now since we os.Rename the file here. - case newEnabled && !previousEnabled: - _, _ = s.statusApplier.ApplyStatus(ctx, constants.Running) - case !newEnabled && previousEnabled: - _, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped) - } - - return "settings updated" -} diff --git a/internal/portforward/state/startdata.go b/internal/portforward/state/startdata.go deleted file mode 100644 index dc3c779e..00000000 --- a/internal/portforward/state/startdata.go +++ /dev/null @@ -1,26 +0,0 @@ -package state - -import ( - "net/netip" - - "github.com/qdm12/gluetun/internal/provider" -) - -type StartData struct { - PortForwarder provider.PortForwarder - Gateway netip.Addr // needed for PIA - ServerName string // needed for PIA - Interface string // tun0 for example -} - -func (s *State) GetStartData() (startData StartData) { - s.startDataMu.RLock() - defer s.startDataMu.RUnlock() - return s.startData -} - -func (s *State) SetStartData(startData StartData) { - s.startDataMu.Lock() - defer s.startDataMu.Unlock() - s.startData = startData -} diff --git a/internal/portforward/state/state.go b/internal/portforward/state/state.go deleted file mode 100644 index 7ec3bd30..00000000 --- a/internal/portforward/state/state.go +++ /dev/null @@ -1,35 +0,0 @@ -package state - -import ( - "context" - "sync" - - "github.com/qdm12/gluetun/internal/configuration/settings" - "github.com/qdm12/gluetun/internal/models" -) - -func New(statusApplier StatusApplier, - settings settings.PortForwarding) *State { - return &State{ - statusApplier: statusApplier, - settings: settings, - } -} - -type State struct { - statusApplier StatusApplier - - settings settings.PortForwarding - settingsMu sync.RWMutex - - portForwarded uint16 - portForwardedMu sync.RWMutex - - startData StartData - startDataMu sync.RWMutex -} - -type StatusApplier interface { - ApplyStatus(ctx context.Context, status models.LoopStatus) ( - outcome string, err error) -} diff --git a/internal/portforward/status.go b/internal/portforward/status.go deleted file mode 100644 index 76c5136d..00000000 --- a/internal/portforward/status.go +++ /dev/null @@ -1,27 +0,0 @@ -package portforward - -import ( - "context" - - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/portforward/state" -) - -func (l *Loop) GetStatus() (status models.LoopStatus) { - return l.statusManager.GetStatus() -} - -type StartData = state.StartData - -func (l *Loop) Start(ctx context.Context, data StartData) ( - outcome string, err error) { - l.startMu.Lock() - defer l.startMu.Unlock() - l.state.SetStartData(data) - return l.statusManager.ApplyStatus(ctx, constants.Running) -} - -func (l *Loop) Stop(ctx context.Context) (outcome string, err error) { - return l.statusManager.ApplyStatus(ctx, constants.Stopped) -} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 94584087..5b3e4efb 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -21,6 +21,7 @@ type Provider interface { } type PortForwarder interface { + Name() string PortForward(ctx context.Context, client *http.Client, logger utils.Logger, gateway netip.Addr, serverName string) ( port uint16, err error) diff --git a/internal/vpn/cleanup.go b/internal/vpn/cleanup.go index 3b5ac484..c0ea8959 100644 --- a/internal/vpn/cleanup.go +++ b/internal/vpn/cleanup.go @@ -2,14 +2,14 @@ package vpn import ( "context" - "time" + "errors" "github.com/qdm12/gluetun/internal/models" ) -func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) { +func (l *Loop) cleanup(vpnProvider string) { for _, vpnPort := range l.vpnInputPorts { - err := l.fw.RemoveAllowedPort(ctx, vpnPort) + err := l.fw.RemoveAllowedPort(context.Background(), vpnPort) if err != nil { l.logger.Error("cannot remove allowed input port from firewall: " + err.Error()) } @@ -17,11 +17,11 @@ func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) { l.publicip.SetData(models.PublicIP{}) // clear public IP address data - if pfEnabled { - const pfTimeout = 100 * time.Millisecond - err := l.stopPortForwarding(ctx, pfTimeout) - if err != nil { - l.logger.Error("cannot stop port forwarding: " + err.Error()) + err := l.stopPortForwarding(vpnProvider) + if err != nil { + portForwardingAlreadyStopped := errors.Is(err, context.Canceled) + if !portForwardingAlreadyStopped { + l.logger.Error("stopping port forwarding: " + err.Error()) } } } diff --git a/internal/vpn/helpers.go b/internal/vpn/helpers.go index 851fb230..a6b292d6 100644 --- a/internal/vpn/helpers.go +++ b/internal/vpn/helpers.go @@ -8,6 +8,8 @@ import ( "github.com/qdm12/gluetun/internal/models" ) +func ptrTo[T any](value T) *T { return &value } + // waitForError waits 100ms for an error in the waitError channel. func (l *Loop) waitForError(ctx context.Context, waitError chan error) (err error) { diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 3d05aac2..792404d3 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -7,7 +7,7 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/netlink" - "github.com/qdm12/gluetun/internal/portforward" + portforward "github.com/qdm12/gluetun/internal/portforward/service" "github.com/qdm12/gluetun/internal/provider" ) @@ -22,8 +22,7 @@ type Routing interface { } type PortForward interface { - Start(ctx context.Context, data portforward.StartData) (outcome string, err error) - Stop(ctx context.Context) (outcome string, err error) + UpdateWith(settings portforward.Settings) (err error) } type OpenVPN interface { diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index e2294545..d05d1f8e 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -1,47 +1,35 @@ package vpn import ( - "context" "fmt" - "time" - "github.com/qdm12/gluetun/internal/portforward" + "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/portforward/service" ) -func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) { - if !data.portForwarding { - return nil - } - - // only used for PIA for now +func (l *Loop) startPortForwarding(data tunnelUpData) (err error) { gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf) if err != nil { return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err) } l.logger.Info("VPN gateway IP address: " + gateway.String()) - pfData := portforward.StartData{ + partialUpdate := service.Settings{ PortForwarder: data.portForwarder, Gateway: gateway, - ServerName: data.serverName, Interface: data.vpnIntf, + ServerName: data.serverName, + VPNProvider: data.portForwarder.Name(), } - _, err = l.portForward.Start(ctx, pfData) - if err != nil { - return fmt.Errorf("starting port forwarding: %w", err) - } - - return nil + return l.portForward.UpdateWith(partialUpdate) } -func (l *Loop) stopPortForwarding(ctx context.Context, - timeout time.Duration) (err error) { - if timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() +func (l *Loop) stopPortForwarding(vpnProvider string) (err error) { + partialUpdate := service.Settings{ + VPNProvider: vpnProvider, + UserSettings: settings.PortForwarding{ + Enabled: ptrTo(false), + }, } - - _, err = l.portForward.Stop(ctx) - return err + return l.portForward.UpdateWith(partialUpdate) } diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 082cf9ba..ca279c4b 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -22,10 +22,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { providerConf := l.providers.Get(*settings.Provider.Name) - portForwarding := *settings.Provider.PortForwarding.Enabled customPortForwardingProvider := *settings.Provider.PortForwarding.Provider portForwader := providerConf - if portForwarding && customPortForwardingProvider != "" { + if customPortForwardingProvider != "" { portForwader = l.providers.Get(customPortForwardingProvider) } @@ -49,10 +48,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ - portForwarding: portForwarding, - serverName: serverName, - portForwarder: portForwader, - vpnIntf: vpnInterface, + serverName: serverName, + portForwarder: portForwader, + vpnIntf: vpnInterface, } openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) @@ -76,7 +74,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { case <-tunnelReady: go l.onTunnelUp(openvpnCtx, tunnelUpData) case <-ctx.Done(): - l.cleanup(context.Background(), portForwarding) + l.cleanup(portForwader.Name()) openvpnCancel() <-waitError close(waitError) @@ -84,7 +82,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { case <-l.stop: l.userTrigger = true l.logger.Info("stopping") - l.cleanup(context.Background(), portForwarding) + l.cleanup(portForwader.Name()) openvpnCancel() <-waitError // do not close waitError or the waitError @@ -97,7 +95,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { case err := <-waitError: // unexpected error l.statusManager.Lock() // prevent SetStatus from running in parallel - l.cleanup(context.Background(), portForwarding) + l.cleanup(portForwader.Name()) openvpnCancel() l.statusManager.SetStatus(constants.Crashed) l.logAndWait(ctx, err) diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 366d5a28..e58c7c80 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -10,10 +10,9 @@ import ( type tunnelUpData struct { // Port forwarding - portForwarding bool - vpnIntf string - serverName string - portForwarder provider.PortForwarder + vpnIntf string + serverName string + portForwarder provider.PortForwarder } func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { @@ -42,7 +41,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { } } - err := l.startPortForwarding(ctx, data) + err := l.startPortForwarding(data) if err != nil { l.logger.Error(err.Error()) }