diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index 266478c1..c819317a 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -12,7 +12,7 @@ import ( type Loop struct { // State - settings service.Settings + settings Settings settingsMutex sync.RWMutex service Service // Fixed injected objets @@ -28,7 +28,7 @@ type Loop struct { runCtx context.Context //nolint:containedctx runCancel context.CancelFunc runDone <-chan struct{} - updateTrigger chan<- service.Settings + updateTrigger chan<- Settings updatedResult <-chan error } @@ -36,8 +36,12 @@ func NewLoop(settings settings.PortForwarding, routing Routing, client *http.Client, portAllower PortAllower, logger Logger, uid, gid int) *Loop { return &Loop{ - settings: service.Settings{ - UserSettings: settings, + settings: Settings{ + VPNIsUp: ptrTo(false), + Service: service.Settings{ + Enabled: settings.Enabled, + Filepath: *settings.Filepath, + }, }, routing: routing, client: client, @@ -57,24 +61,22 @@ func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) { runDone := make(chan struct{}) l.runDone = runDone - updateTrigger := make(chan service.Settings) + updateTrigger := make(chan Settings) l.updateTrigger = updateTrigger updateResult := make(chan error) l.updatedResult = updateResult runErrorCh := make(chan error) - go l.run(l.runCtx, runDone, runErrorCh, - l.settings, updateTrigger, updateResult) + go l.run(l.runCtx, runDone, runErrorCh, updateTrigger, updateResult) return runErrorCh, nil } func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, - runErrorCh chan<- error, initialSettings service.Settings, - updateTrigger <-chan service.Settings, updateResult chan<- error) { + runErrorCh chan<- error, updateTrigger <-chan Settings, + updateResult chan<- error) { defer close(runDone) - settings := initialSettings var serviceRunError <-chan error for { updateReceived := false @@ -83,18 +85,20 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, // Stop call takes care of stopping the service return case partialUpdate := <-updateTrigger: - updatedSettings, err := settings.UpdateWith(partialUpdate) + updatedSettings, err := l.settings.updateWith(partialUpdate) if err != nil { updateResult <- err continue } - settings = updatedSettings updateReceived = true + l.settingsMutex.Lock() + l.settings = updatedSettings + l.settingsMutex.Unlock() case err := <-serviceRunError: l.logger.Error(err.Error()) } - firstRun := l.service == nil + firstRun := serviceRunError == nil if !firstRun { err := l.service.Stop() if err != nil { @@ -103,7 +107,11 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, } } - l.service = service.New(settings, l.routing, l.client, + serviceSettings := l.settings.Service.Copy() + // Only enable port forward if the VPN tunnel is up + *serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp + + l.service = service.New(serviceSettings, l.routing, l.client, l.portAllower, l.logger, l.uid, l.gid) var err error @@ -119,16 +127,10 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, } return } - - // Service is created and started successfully, so update - // the settings for external calls such as GetSettings. - l.settingsMutex.Lock() - l.settings = settings - l.settingsMutex.Unlock() } } -func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) { +func (l *Loop) UpdateWith(partialUpdate Settings) (err error) { select { case l.updateTrigger <- partialUpdate: select { @@ -159,3 +161,7 @@ func (l *Loop) GetPortForwarded() (port uint16) { } return l.service.GetPortForwarded() } + +func ptrTo[T any](value T) *T { + return &value +} diff --git a/internal/portforward/service/fs.go b/internal/portforward/service/fs.go index bc10098e..ed24fdf1 100644 --- a/internal/portforward/service/fs.go +++ b/internal/portforward/service/fs.go @@ -6,7 +6,7 @@ import ( ) func (s *Service) writePortForwardedFile(port uint16) (err error) { - filepath := *s.settings.UserSettings.Filepath + filepath := s.settings.Filepath s.logger.Info("writing port file " + filepath) const perms = os.FileMode(0644) err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms) diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index 43ebb118..78f9b9a3 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -4,69 +4,51 @@ import ( "errors" "fmt" - "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gosettings" ) type Settings struct { - UserSettings settings.PortForwarding + Enabled *bool PortForwarder PortForwarder + Filepath string Interface string // needed for PIA and ProtonVPN, tun0 for example ServerName string // needed for PIA - VPNProvider string // used to validate new settings } -// UpdateWith deep copies the receiving settings, overrides the copy with -// fields set in the partialUpdate argument, validates the new settings -// and returns them if they are valid, or returns an error otherwise. -// In all cases, the receiving settings are unmodified. -func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) { - updatedSettings = s.copy() - updatedSettings.overrideWith(partialUpdate) - err = updatedSettings.validate() - if err != nil { - return updatedSettings, fmt.Errorf("validating new settings: %w", err) - } - return updatedSettings, nil -} - -func (s Settings) copy() (copied Settings) { - copied.UserSettings = s.UserSettings.Copy() +func (s Settings) Copy() (copied Settings) { + copied.Enabled = gosettings.CopyPointer(s.Enabled) copied.PortForwarder = s.PortForwarder + copied.Filepath = s.Filepath copied.Interface = s.Interface copied.ServerName = s.ServerName - copied.VPNProvider = s.VPNProvider return copied } -func (s *Settings) overrideWith(update Settings) { - s.UserSettings.OverrideWith(update.UserSettings) +func (s *Settings) OverrideWith(update Settings) { + s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled) s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder) + s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath) s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface) s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName) - s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider) } var ( - ErrVPNProviderNotSet = errors.New("VPN provider not set") - ErrServerNameNotSet = errors.New("server name not set") - ErrPortForwarderNotSet = errors.New("port forwarder not set") - ErrGatewayNotSet = errors.New("gateway not set") - ErrInterfaceNotSet = errors.New("interface not set") + ErrServerNameNotSet = errors.New("server name not set") + ErrFilepathNotSet = errors.New("file path not set") + ErrInterfaceNotSet = errors.New("interface not set") ) -func (s *Settings) validate() (err error) { +func (s *Settings) Validate() (err error) { switch { - case s.VPNProvider == "": - return fmt.Errorf("%w", ErrVPNProviderNotSet) - case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "": - return fmt.Errorf("%w", ErrServerNameNotSet) - case s.PortForwarder == nil: - return fmt.Errorf("%w", ErrPortForwarderNotSet) + // Port forwarder can be nil when the loop updates + // to stop the service. + case s.Filepath == "": + return fmt.Errorf("%w", ErrFilepathNotSet) case s.Interface == "": return fmt.Errorf("%w", ErrInterfaceNotSet) + case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerName == "": + return fmt.Errorf("%w", ErrServerNameNotSet) } - - return s.UserSettings.Validate(s.VPNProvider) + return nil } diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index 3afa0311..a938495d 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -11,7 +11,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) s.startStopMutex.Lock() defer s.startStopMutex.Unlock() - if !*s.settings.UserSettings.Enabled { + if !*s.settings.Enabled { return nil, nil //nolint:nilnil } @@ -64,6 +64,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) if !crashed { // stopped by Stop call return } + s.startStopMutex.Lock() + defer s.startStopMutex.Unlock() _ = s.cleanup() runError <- err }(keepPortCtx, s.settings.PortForwarder, obj, runErrorCh, keepPortDoneCh) diff --git a/internal/portforward/service/stop.go b/internal/portforward/service/stop.go index bd6e1871..cb7f8efa 100644 --- a/internal/portforward/service/stop.go +++ b/internal/portforward/service/stop.go @@ -14,6 +14,7 @@ func (s *Service) Stop() (err error) { serviceNotRunning := s.port == 0 s.portMutex.RUnlock() if serviceNotRunning { + // TODO replace with goservices.ErrAlreadyStopped return nil } @@ -36,7 +37,7 @@ func (s *Service) cleanup() (err error) { s.port = 0 - filepath := *s.settings.UserSettings.Filepath + filepath := s.settings.Filepath s.logger.Info("removing port file " + filepath) err = os.Remove(filepath) if err != nil { diff --git a/internal/portforward/settings.go b/internal/portforward/settings.go new file mode 100644 index 00000000..6ac7d582 --- /dev/null +++ b/internal/portforward/settings.go @@ -0,0 +1,43 @@ +package portforward + +import ( + "github.com/qdm12/gluetun/internal/portforward/service" + "github.com/qdm12/gosettings" +) + +type Settings struct { + // VPNIsUp can be optionally set to signal the loop + // the VPN is up (true) or down (false). If left to nil, + // it is assumed the VPN is in the same previous state. + VPNIsUp *bool + Service service.Settings +} + +// updateWith deep copies the receiving settings, overrides the copy with +// fields set in the partialUpdate argument, validates the new settings +// and returns them if they are valid, or returns an error otherwise. +// In all cases, the receiving settings are unmodified. +func (s Settings) updateWith(partialUpdate Settings) (updated Settings, err error) { + updated = s.copy() + updated.overrideWith(partialUpdate) + err = updated.validate() + if err != nil { + return updated, err + } + return updated, nil +} + +func (s Settings) copy() (copied Settings) { + copied.VPNIsUp = gosettings.CopyPointer(s.VPNIsUp) + copied.Service = s.Service.Copy() + return copied +} + +func (s *Settings) overrideWith(update Settings) { + s.VPNIsUp = gosettings.OverrideWithPointer(s.VPNIsUp, update.VPNIsUp) + s.Service.OverrideWith(update.Service) +} + +func (s Settings) validate() (err error) { + return s.Service.Validate() +} diff --git a/internal/vpn/cleanup.go b/internal/vpn/cleanup.go index 4319577e..2b5c0919 100644 --- a/internal/vpn/cleanup.go +++ b/internal/vpn/cleanup.go @@ -5,7 +5,7 @@ import ( "errors" ) -func (l *Loop) cleanup(vpnProvider string) { +func (l *Loop) cleanup() { for _, vpnPort := range l.vpnInputPorts { err := l.fw.RemoveAllowedPort(context.Background(), vpnPort) if err != nil { @@ -18,7 +18,7 @@ func (l *Loop) cleanup(vpnProvider string) { l.logger.Error("clearing public IP data: " + err.Error()) } - err = l.stopPortForwarding(vpnProvider) + err = l.stopPortForwarding() if err != nil { portForwardingAlreadyStopped := errors.Is(err, context.Canceled) if !portForwardingAlreadyStopped { diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 4b84b7b6..5ca8eca9 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -7,7 +7,7 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/netlink" - portforward "github.com/qdm12/gluetun/internal/portforward/service" + portforward "github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider/utils" ) diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index 6be92029..f6f6b117 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/portforward/service" pfutils "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -23,21 +23,20 @@ func getPortForwarder(provider Provider, providers Providers, //nolint:ireturn } func (l *Loop) startPortForwarding(data tunnelUpData) (err error) { - partialUpdate := service.Settings{ - PortForwarder: data.portForwarder, - Interface: data.vpnIntf, - ServerName: data.serverName, - VPNProvider: data.portForwarder.Name(), + partialUpdate := portforward.Settings{ + VPNIsUp: ptrTo(true), + Service: service.Settings{ + PortForwarder: data.portForwarder, + Interface: data.vpnIntf, + ServerName: data.serverName, + }, } return l.portForward.UpdateWith(partialUpdate) } -func (l *Loop) stopPortForwarding(vpnProvider string) (err error) { - partialUpdate := service.Settings{ - VPNProvider: vpnProvider, - UserSettings: settings.PortForwarding{ - Enabled: ptrTo(false), - }, +func (l *Loop) stopPortForwarding() (err error) { + partialUpdate := portforward.Settings{ + VPNIsUp: ptrTo(false), } return l.portForward.UpdateWith(partialUpdate) } diff --git a/internal/vpn/run.go b/internal/vpn/run.go index c8e7ff21..1666d74a 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -71,7 +71,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { case <-tunnelReady: go l.onTunnelUp(openvpnCtx, tunnelUpData) case <-ctx.Done(): - l.cleanup(portForwarder.Name()) + l.cleanup() openvpnCancel() <-waitError close(waitError) @@ -79,7 +79,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { case <-l.stop: l.userTrigger = true l.logger.Info("stopping") - l.cleanup(portForwarder.Name()) + l.cleanup() openvpnCancel() <-waitError // do not close waitError or the waitError @@ -92,7 +92,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { case err := <-waitError: // unexpected error l.statusManager.Lock() // prevent SetStatus from running in parallel - l.cleanup(portForwarder.Name()) + l.cleanup() openvpnCancel() l.statusManager.SetStatus(constants.Crashed) l.logAndWait(ctx, err)