fix(portforward): trigger after VPN restart
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type Loop struct {
|
||||
// State
|
||||
settings service.Settings
|
||||
settings Settings
|
||||
settingsMutex sync.RWMutex
|
||||
service Service
|
||||
// Fixed injected objets
|
||||
@@ -28,7 +28,7 @@ type Loop struct {
|
||||
runCtx context.Context //nolint:containedctx
|
||||
runCancel context.CancelFunc
|
||||
runDone <-chan struct{}
|
||||
updateTrigger chan<- service.Settings
|
||||
updateTrigger chan<- Settings
|
||||
updatedResult <-chan error
|
||||
}
|
||||
|
||||
@@ -36,8 +36,12 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||
client *http.Client, portAllower PortAllower,
|
||||
logger Logger, uid, gid int) *Loop {
|
||||
return &Loop{
|
||||
settings: service.Settings{
|
||||
UserSettings: settings,
|
||||
settings: Settings{
|
||||
VPNIsUp: ptrTo(false),
|
||||
Service: service.Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Filepath: *settings.Filepath,
|
||||
},
|
||||
},
|
||||
routing: routing,
|
||||
client: client,
|
||||
@@ -57,24 +61,22 @@ func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) {
|
||||
runDone := make(chan struct{})
|
||||
l.runDone = runDone
|
||||
|
||||
updateTrigger := make(chan service.Settings)
|
||||
updateTrigger := make(chan Settings)
|
||||
l.updateTrigger = updateTrigger
|
||||
updateResult := make(chan error)
|
||||
l.updatedResult = updateResult
|
||||
runErrorCh := make(chan error)
|
||||
|
||||
go l.run(l.runCtx, runDone, runErrorCh,
|
||||
l.settings, updateTrigger, updateResult)
|
||||
go l.run(l.runCtx, runDone, runErrorCh, updateTrigger, updateResult)
|
||||
|
||||
return runErrorCh, nil
|
||||
}
|
||||
|
||||
func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
||||
runErrorCh chan<- error, initialSettings service.Settings,
|
||||
updateTrigger <-chan service.Settings, updateResult chan<- error) {
|
||||
runErrorCh chan<- error, updateTrigger <-chan Settings,
|
||||
updateResult chan<- error) {
|
||||
defer close(runDone)
|
||||
|
||||
settings := initialSettings
|
||||
var serviceRunError <-chan error
|
||||
for {
|
||||
updateReceived := false
|
||||
@@ -83,18 +85,20 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
||||
// Stop call takes care of stopping the service
|
||||
return
|
||||
case partialUpdate := <-updateTrigger:
|
||||
updatedSettings, err := settings.UpdateWith(partialUpdate)
|
||||
updatedSettings, err := l.settings.updateWith(partialUpdate)
|
||||
if err != nil {
|
||||
updateResult <- err
|
||||
continue
|
||||
}
|
||||
settings = updatedSettings
|
||||
updateReceived = true
|
||||
l.settingsMutex.Lock()
|
||||
l.settings = updatedSettings
|
||||
l.settingsMutex.Unlock()
|
||||
case err := <-serviceRunError:
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
firstRun := l.service == nil
|
||||
firstRun := serviceRunError == nil
|
||||
if !firstRun {
|
||||
err := l.service.Stop()
|
||||
if err != nil {
|
||||
@@ -103,7 +107,11 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
||||
}
|
||||
}
|
||||
|
||||
l.service = service.New(settings, l.routing, l.client,
|
||||
serviceSettings := l.settings.Service.Copy()
|
||||
// Only enable port forward if the VPN tunnel is up
|
||||
*serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp
|
||||
|
||||
l.service = service.New(serviceSettings, l.routing, l.client,
|
||||
l.portAllower, l.logger, l.uid, l.gid)
|
||||
|
||||
var err error
|
||||
@@ -119,16 +127,10 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Service is created and started successfully, so update
|
||||
// the settings for external calls such as GetSettings.
|
||||
l.settingsMutex.Lock()
|
||||
l.settings = settings
|
||||
l.settingsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) {
|
||||
func (l *Loop) UpdateWith(partialUpdate Settings) (err error) {
|
||||
select {
|
||||
case l.updateTrigger <- partialUpdate:
|
||||
select {
|
||||
@@ -159,3 +161,7 @@ func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
}
|
||||
return l.service.GetPortForwarded()
|
||||
}
|
||||
|
||||
func ptrTo[T any](value T) *T {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Service) writePortForwardedFile(port uint16) (err error) {
|
||||
filepath := *s.settings.UserSettings.Filepath
|
||||
filepath := s.settings.Filepath
|
||||
s.logger.Info("writing port file " + filepath)
|
||||
const perms = os.FileMode(0644)
|
||||
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
|
||||
|
||||
@@ -4,69 +4,51 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gosettings"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
UserSettings settings.PortForwarding
|
||||
Enabled *bool
|
||||
PortForwarder PortForwarder
|
||||
Filepath string
|
||||
Interface string // needed for PIA and ProtonVPN, tun0 for example
|
||||
ServerName string // needed for PIA
|
||||
VPNProvider string // used to validate new settings
|
||||
}
|
||||
|
||||
// UpdateWith deep copies the receiving settings, overrides the copy with
|
||||
// fields set in the partialUpdate argument, validates the new settings
|
||||
// and returns them if they are valid, or returns an error otherwise.
|
||||
// In all cases, the receiving settings are unmodified.
|
||||
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
|
||||
updatedSettings = s.copy()
|
||||
updatedSettings.overrideWith(partialUpdate)
|
||||
err = updatedSettings.validate()
|
||||
if err != nil {
|
||||
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
|
||||
}
|
||||
return updatedSettings, nil
|
||||
}
|
||||
|
||||
func (s Settings) copy() (copied Settings) {
|
||||
copied.UserSettings = s.UserSettings.Copy()
|
||||
func (s Settings) Copy() (copied Settings) {
|
||||
copied.Enabled = gosettings.CopyPointer(s.Enabled)
|
||||
copied.PortForwarder = s.PortForwarder
|
||||
copied.Filepath = s.Filepath
|
||||
copied.Interface = s.Interface
|
||||
copied.ServerName = s.ServerName
|
||||
copied.VPNProvider = s.VPNProvider
|
||||
return copied
|
||||
}
|
||||
|
||||
func (s *Settings) overrideWith(update Settings) {
|
||||
s.UserSettings.OverrideWith(update.UserSettings)
|
||||
func (s *Settings) OverrideWith(update Settings) {
|
||||
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
|
||||
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
|
||||
s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath)
|
||||
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
|
||||
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
|
||||
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVPNProviderNotSet = errors.New("VPN provider not set")
|
||||
ErrServerNameNotSet = errors.New("server name not set")
|
||||
ErrPortForwarderNotSet = errors.New("port forwarder not set")
|
||||
ErrGatewayNotSet = errors.New("gateway not set")
|
||||
ErrInterfaceNotSet = errors.New("interface not set")
|
||||
ErrServerNameNotSet = errors.New("server name not set")
|
||||
ErrFilepathNotSet = errors.New("file path not set")
|
||||
ErrInterfaceNotSet = errors.New("interface not set")
|
||||
)
|
||||
|
||||
func (s *Settings) validate() (err error) {
|
||||
func (s *Settings) Validate() (err error) {
|
||||
switch {
|
||||
case s.VPNProvider == "":
|
||||
return fmt.Errorf("%w", ErrVPNProviderNotSet)
|
||||
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "":
|
||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
||||
case s.PortForwarder == nil:
|
||||
return fmt.Errorf("%w", ErrPortForwarderNotSet)
|
||||
// Port forwarder can be nil when the loop updates
|
||||
// to stop the service.
|
||||
case s.Filepath == "":
|
||||
return fmt.Errorf("%w", ErrFilepathNotSet)
|
||||
case s.Interface == "":
|
||||
return fmt.Errorf("%w", ErrInterfaceNotSet)
|
||||
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerName == "":
|
||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
||||
}
|
||||
|
||||
return s.UserSettings.Validate(s.VPNProvider)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
|
||||
if !*s.settings.UserSettings.Enabled {
|
||||
if !*s.settings.Enabled {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
@@ -64,6 +64,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
if !crashed { // stopped by Stop call
|
||||
return
|
||||
}
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
_ = s.cleanup()
|
||||
runError <- err
|
||||
}(keepPortCtx, s.settings.PortForwarder, obj, runErrorCh, keepPortDoneCh)
|
||||
|
||||
@@ -14,6 +14,7 @@ func (s *Service) Stop() (err error) {
|
||||
serviceNotRunning := s.port == 0
|
||||
s.portMutex.RUnlock()
|
||||
if serviceNotRunning {
|
||||
// TODO replace with goservices.ErrAlreadyStopped
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,7 +37,7 @@ func (s *Service) cleanup() (err error) {
|
||||
|
||||
s.port = 0
|
||||
|
||||
filepath := *s.settings.UserSettings.Filepath
|
||||
filepath := s.settings.Filepath
|
||||
s.logger.Info("removing port file " + filepath)
|
||||
err = os.Remove(filepath)
|
||||
if err != nil {
|
||||
|
||||
43
internal/portforward/settings.go
Normal file
43
internal/portforward/settings.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/portforward/service"
|
||||
"github.com/qdm12/gosettings"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
// VPNIsUp can be optionally set to signal the loop
|
||||
// the VPN is up (true) or down (false). If left to nil,
|
||||
// it is assumed the VPN is in the same previous state.
|
||||
VPNIsUp *bool
|
||||
Service service.Settings
|
||||
}
|
||||
|
||||
// updateWith deep copies the receiving settings, overrides the copy with
|
||||
// fields set in the partialUpdate argument, validates the new settings
|
||||
// and returns them if they are valid, or returns an error otherwise.
|
||||
// In all cases, the receiving settings are unmodified.
|
||||
func (s Settings) updateWith(partialUpdate Settings) (updated Settings, err error) {
|
||||
updated = s.copy()
|
||||
updated.overrideWith(partialUpdate)
|
||||
err = updated.validate()
|
||||
if err != nil {
|
||||
return updated, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (s Settings) copy() (copied Settings) {
|
||||
copied.VPNIsUp = gosettings.CopyPointer(s.VPNIsUp)
|
||||
copied.Service = s.Service.Copy()
|
||||
return copied
|
||||
}
|
||||
|
||||
func (s *Settings) overrideWith(update Settings) {
|
||||
s.VPNIsUp = gosettings.OverrideWithPointer(s.VPNIsUp, update.VPNIsUp)
|
||||
s.Service.OverrideWith(update.Service)
|
||||
}
|
||||
|
||||
func (s Settings) validate() (err error) {
|
||||
return s.Service.Validate()
|
||||
}
|
||||
Reference in New Issue
Block a user