fix(portforward): trigger after VPN restart

This commit is contained in:
Quentin McGaw
2023-09-28 14:00:58 +00:00
parent a194906bdd
commit d4df87286e
10 changed files with 112 additions and 79 deletions

View File

@@ -12,7 +12,7 @@ import (
type Loop struct { type Loop struct {
// State // State
settings service.Settings settings Settings
settingsMutex sync.RWMutex settingsMutex sync.RWMutex
service Service service Service
// Fixed injected objets // Fixed injected objets
@@ -28,7 +28,7 @@ type Loop struct {
runCtx context.Context //nolint:containedctx runCtx context.Context //nolint:containedctx
runCancel context.CancelFunc runCancel context.CancelFunc
runDone <-chan struct{} runDone <-chan struct{}
updateTrigger chan<- service.Settings updateTrigger chan<- Settings
updatedResult <-chan error updatedResult <-chan error
} }
@@ -36,8 +36,12 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
client *http.Client, portAllower PortAllower, client *http.Client, portAllower PortAllower,
logger Logger, uid, gid int) *Loop { logger Logger, uid, gid int) *Loop {
return &Loop{ return &Loop{
settings: service.Settings{ settings: Settings{
UserSettings: settings, VPNIsUp: ptrTo(false),
Service: service.Settings{
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
},
}, },
routing: routing, routing: routing,
client: client, client: client,
@@ -57,24 +61,22 @@ func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) {
runDone := make(chan struct{}) runDone := make(chan struct{})
l.runDone = runDone l.runDone = runDone
updateTrigger := make(chan service.Settings) updateTrigger := make(chan Settings)
l.updateTrigger = updateTrigger l.updateTrigger = updateTrigger
updateResult := make(chan error) updateResult := make(chan error)
l.updatedResult = updateResult l.updatedResult = updateResult
runErrorCh := make(chan error) runErrorCh := make(chan error)
go l.run(l.runCtx, runDone, runErrorCh, go l.run(l.runCtx, runDone, runErrorCh, updateTrigger, updateResult)
l.settings, updateTrigger, updateResult)
return runErrorCh, nil return runErrorCh, nil
} }
func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
runErrorCh chan<- error, initialSettings service.Settings, runErrorCh chan<- error, updateTrigger <-chan Settings,
updateTrigger <-chan service.Settings, updateResult chan<- error) { updateResult chan<- error) {
defer close(runDone) defer close(runDone)
settings := initialSettings
var serviceRunError <-chan error var serviceRunError <-chan error
for { for {
updateReceived := false updateReceived := false
@@ -83,18 +85,20 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
// Stop call takes care of stopping the service // Stop call takes care of stopping the service
return return
case partialUpdate := <-updateTrigger: case partialUpdate := <-updateTrigger:
updatedSettings, err := settings.UpdateWith(partialUpdate) updatedSettings, err := l.settings.updateWith(partialUpdate)
if err != nil { if err != nil {
updateResult <- err updateResult <- err
continue continue
} }
settings = updatedSettings
updateReceived = true updateReceived = true
l.settingsMutex.Lock()
l.settings = updatedSettings
l.settingsMutex.Unlock()
case err := <-serviceRunError: case err := <-serviceRunError:
l.logger.Error(err.Error()) l.logger.Error(err.Error())
} }
firstRun := l.service == nil firstRun := serviceRunError == nil
if !firstRun { if !firstRun {
err := l.service.Stop() err := l.service.Stop()
if err != nil { if err != nil {
@@ -103,7 +107,11 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
} }
} }
l.service = service.New(settings, l.routing, l.client, serviceSettings := l.settings.Service.Copy()
// Only enable port forward if the VPN tunnel is up
*serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp
l.service = service.New(serviceSettings, l.routing, l.client,
l.portAllower, l.logger, l.uid, l.gid) l.portAllower, l.logger, l.uid, l.gid)
var err error var err error
@@ -119,16 +127,10 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
} }
return return
} }
// Service is created and started successfully, so update
// the settings for external calls such as GetSettings.
l.settingsMutex.Lock()
l.settings = settings
l.settingsMutex.Unlock()
} }
} }
func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) { func (l *Loop) UpdateWith(partialUpdate Settings) (err error) {
select { select {
case l.updateTrigger <- partialUpdate: case l.updateTrigger <- partialUpdate:
select { select {
@@ -159,3 +161,7 @@ func (l *Loop) GetPortForwarded() (port uint16) {
} }
return l.service.GetPortForwarded() return l.service.GetPortForwarded()
} }
func ptrTo[T any](value T) *T {
return &value
}

View File

@@ -6,7 +6,7 @@ import (
) )
func (s *Service) writePortForwardedFile(port uint16) (err error) { func (s *Service) writePortForwardedFile(port uint16) (err error) {
filepath := *s.settings.UserSettings.Filepath filepath := s.settings.Filepath
s.logger.Info("writing port file " + filepath) s.logger.Info("writing port file " + filepath)
const perms = os.FileMode(0644) const perms = os.FileMode(0644)
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms) err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)

View File

@@ -4,69 +4,51 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gosettings" "github.com/qdm12/gosettings"
) )
type Settings struct { type Settings struct {
UserSettings settings.PortForwarding Enabled *bool
PortForwarder PortForwarder PortForwarder PortForwarder
Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA ServerName string // needed for PIA
VPNProvider string // used to validate new settings
} }
// UpdateWith deep copies the receiving settings, overrides the copy with func (s Settings) Copy() (copied Settings) {
// fields set in the partialUpdate argument, validates the new settings copied.Enabled = gosettings.CopyPointer(s.Enabled)
// and returns them if they are valid, or returns an error otherwise.
// In all cases, the receiving settings are unmodified.
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
updatedSettings = s.copy()
updatedSettings.overrideWith(partialUpdate)
err = updatedSettings.validate()
if err != nil {
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
}
return updatedSettings, nil
}
func (s Settings) copy() (copied Settings) {
copied.UserSettings = s.UserSettings.Copy()
copied.PortForwarder = s.PortForwarder copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.Interface = s.Interface copied.Interface = s.Interface
copied.ServerName = s.ServerName copied.ServerName = s.ServerName
copied.VPNProvider = s.VPNProvider
return copied return copied
} }
func (s *Settings) overrideWith(update Settings) { func (s *Settings) OverrideWith(update Settings) {
s.UserSettings.OverrideWith(update.UserSettings) s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder) s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface) s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName) s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
} }
var ( var (
ErrVPNProviderNotSet = errors.New("VPN provider not set")
ErrServerNameNotSet = errors.New("server name not set") ErrServerNameNotSet = errors.New("server name not set")
ErrPortForwarderNotSet = errors.New("port forwarder not set") ErrFilepathNotSet = errors.New("file path not set")
ErrGatewayNotSet = errors.New("gateway not set")
ErrInterfaceNotSet = errors.New("interface not set") ErrInterfaceNotSet = errors.New("interface not set")
) )
func (s *Settings) validate() (err error) { func (s *Settings) Validate() (err error) {
switch { switch {
case s.VPNProvider == "": // Port forwarder can be nil when the loop updates
return fmt.Errorf("%w", ErrVPNProviderNotSet) // to stop the service.
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "": case s.Filepath == "":
return fmt.Errorf("%w", ErrServerNameNotSet) return fmt.Errorf("%w", ErrFilepathNotSet)
case s.PortForwarder == nil:
return fmt.Errorf("%w", ErrPortForwarderNotSet)
case s.Interface == "": case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet) return fmt.Errorf("%w", ErrInterfaceNotSet)
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
} }
return nil
return s.UserSettings.Validate(s.VPNProvider)
} }

View File

@@ -11,7 +11,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.startStopMutex.Lock() s.startStopMutex.Lock()
defer s.startStopMutex.Unlock() defer s.startStopMutex.Unlock()
if !*s.settings.UserSettings.Enabled { if !*s.settings.Enabled {
return nil, nil //nolint:nilnil return nil, nil //nolint:nilnil
} }
@@ -64,6 +64,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
if !crashed { // stopped by Stop call if !crashed { // stopped by Stop call
return return
} }
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
_ = s.cleanup() _ = s.cleanup()
runError <- err runError <- err
}(keepPortCtx, s.settings.PortForwarder, obj, runErrorCh, keepPortDoneCh) }(keepPortCtx, s.settings.PortForwarder, obj, runErrorCh, keepPortDoneCh)

View File

@@ -14,6 +14,7 @@ func (s *Service) Stop() (err error) {
serviceNotRunning := s.port == 0 serviceNotRunning := s.port == 0
s.portMutex.RUnlock() s.portMutex.RUnlock()
if serviceNotRunning { if serviceNotRunning {
// TODO replace with goservices.ErrAlreadyStopped
return nil return nil
} }
@@ -36,7 +37,7 @@ func (s *Service) cleanup() (err error) {
s.port = 0 s.port = 0
filepath := *s.settings.UserSettings.Filepath filepath := s.settings.Filepath
s.logger.Info("removing port file " + filepath) s.logger.Info("removing port file " + filepath)
err = os.Remove(filepath) err = os.Remove(filepath)
if err != nil { if err != nil {

View File

@@ -0,0 +1,43 @@
package portforward
import (
"github.com/qdm12/gluetun/internal/portforward/service"
"github.com/qdm12/gosettings"
)
type Settings struct {
// VPNIsUp can be optionally set to signal the loop
// the VPN is up (true) or down (false). If left to nil,
// it is assumed the VPN is in the same previous state.
VPNIsUp *bool
Service service.Settings
}
// updateWith deep copies the receiving settings, overrides the copy with
// fields set in the partialUpdate argument, validates the new settings
// and returns them if they are valid, or returns an error otherwise.
// In all cases, the receiving settings are unmodified.
func (s Settings) updateWith(partialUpdate Settings) (updated Settings, err error) {
updated = s.copy()
updated.overrideWith(partialUpdate)
err = updated.validate()
if err != nil {
return updated, err
}
return updated, nil
}
func (s Settings) copy() (copied Settings) {
copied.VPNIsUp = gosettings.CopyPointer(s.VPNIsUp)
copied.Service = s.Service.Copy()
return copied
}
func (s *Settings) overrideWith(update Settings) {
s.VPNIsUp = gosettings.OverrideWithPointer(s.VPNIsUp, update.VPNIsUp)
s.Service.OverrideWith(update.Service)
}
func (s Settings) validate() (err error) {
return s.Service.Validate()
}

View File

@@ -5,7 +5,7 @@ import (
"errors" "errors"
) )
func (l *Loop) cleanup(vpnProvider string) { func (l *Loop) cleanup() {
for _, vpnPort := range l.vpnInputPorts { for _, vpnPort := range l.vpnInputPorts {
err := l.fw.RemoveAllowedPort(context.Background(), vpnPort) err := l.fw.RemoveAllowedPort(context.Background(), vpnPort)
if err != nil { if err != nil {
@@ -18,7 +18,7 @@ func (l *Loop) cleanup(vpnProvider string) {
l.logger.Error("clearing public IP data: " + err.Error()) l.logger.Error("clearing public IP data: " + err.Error())
} }
err = l.stopPortForwarding(vpnProvider) err = l.stopPortForwarding()
if err != nil { if err != nil {
portForwardingAlreadyStopped := errors.Is(err, context.Canceled) portForwardingAlreadyStopped := errors.Is(err, context.Canceled)
if !portForwardingAlreadyStopped { if !portForwardingAlreadyStopped {

View File

@@ -7,7 +7,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
portforward "github.com/qdm12/gluetun/internal/portforward/service" portforward "github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )

View File

@@ -5,7 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/portforward/service" "github.com/qdm12/gluetun/internal/portforward/service"
pfutils "github.com/qdm12/gluetun/internal/provider/utils" pfutils "github.com/qdm12/gluetun/internal/provider/utils"
) )
@@ -23,21 +23,20 @@ func getPortForwarder(provider Provider, providers Providers, //nolint:ireturn
} }
func (l *Loop) startPortForwarding(data tunnelUpData) (err error) { func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
partialUpdate := service.Settings{ partialUpdate := portforward.Settings{
VPNIsUp: ptrTo(true),
Service: service.Settings{
PortForwarder: data.portForwarder, PortForwarder: data.portForwarder,
Interface: data.vpnIntf, Interface: data.vpnIntf,
ServerName: data.serverName, ServerName: data.serverName,
VPNProvider: data.portForwarder.Name(), },
} }
return l.portForward.UpdateWith(partialUpdate) return l.portForward.UpdateWith(partialUpdate)
} }
func (l *Loop) stopPortForwarding(vpnProvider string) (err error) { func (l *Loop) stopPortForwarding() (err error) {
partialUpdate := service.Settings{ partialUpdate := portforward.Settings{
VPNProvider: vpnProvider, VPNIsUp: ptrTo(false),
UserSettings: settings.PortForwarding{
Enabled: ptrTo(false),
},
} }
return l.portForward.UpdateWith(partialUpdate) return l.portForward.UpdateWith(partialUpdate)
} }

View File

@@ -71,7 +71,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-tunnelReady: case <-tunnelReady:
go l.onTunnelUp(openvpnCtx, tunnelUpData) go l.onTunnelUp(openvpnCtx, tunnelUpData)
case <-ctx.Done(): case <-ctx.Done():
l.cleanup(portForwarder.Name()) l.cleanup()
openvpnCancel() openvpnCancel()
<-waitError <-waitError
close(waitError) close(waitError)
@@ -79,7 +79,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-l.stop: case <-l.stop:
l.userTrigger = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
l.cleanup(portForwarder.Name()) l.cleanup()
openvpnCancel() openvpnCancel()
<-waitError <-waitError
// do not close waitError or the waitError // do not close waitError or the waitError
@@ -92,7 +92,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
l.statusManager.Lock() // prevent SetStatus from running in parallel l.statusManager.Lock() // prevent SetStatus from running in parallel
l.cleanup(portForwarder.Name()) l.cleanup()
openvpnCancel() openvpnCancel()
l.statusManager.SetStatus(constants.Crashed) l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)