fix(portforward): rework run loop and fix deadlocks (#1874)

This commit is contained in:
Quentin McGaw
2023-09-23 12:57:12 +02:00
committed by GitHub
parent c435bbb32c
commit 71201411f4
30 changed files with 453 additions and 476 deletions

View File

@@ -377,9 +377,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
portForwardLogger := logger.New(log.SetComponent("port forwarding")) portForwardLogger := logger.New(log.SetComponent("port forwarding"))
portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding, portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding,
httpClient, firewallConf, portForwardLogger, puid, pgid) httpClient, firewallConf, portForwardLogger, puid, pgid)
portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler( portForwardRunError, _ := portForwardLooper.Start(context.Background())
"port forwarding", goroutine.OptionTimeout(time.Second))
go portForwardLooper.Run(portForwardCtx, portForwardDone)
unboundLogger := logger.New(log.SetComponent("dns")) unboundLogger := logger.New(log.SetComponent("dns"))
unboundLooper := dns.NewLoop(dnsConf, allSettings.DNS, httpClient, unboundLooper := dns.NewLoop(dnsConf, allSettings.DNS, httpClient,
@@ -481,13 +479,21 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
order.OptionOnSuccess(defaultShutdownOnSuccess), order.OptionOnSuccess(defaultShutdownOnSuccess),
order.OptionOnFailure(defaultShutdownOnFailure)) order.OptionOnFailure(defaultShutdownOnFailure))
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler, orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
vpnHandler, portForwardHandler, otherGroupHandler) vpnHandler, otherGroupHandler)
// Start VPN for the first time in a blocking call // Start VPN for the first time in a blocking call
// until the VPN is launched // until the VPN is launched
_, _ = vpnLooper.ApplyStatus(ctx, constants.Running) // TODO option to disable with variable _, _ = vpnLooper.ApplyStatus(ctx, constants.Running) // TODO option to disable with variable
<-ctx.Done() select {
case <-ctx.Done():
err = portForwardLooper.Stop()
if err != nil {
logger.Error("stopping port forward loop: " + err.Error())
}
case err := <-portForwardRunError:
logger.Errorf("port forwarding loop crashed: %s", err)
}
return orderHandler.Shutdown(context.Background()) return orderHandler.Shutdown(context.Background())
} }

View File

@@ -30,7 +30,7 @@ type PortForwarding struct {
Filepath *string `json:"status_file_path"` Filepath *string `json:"status_file_path"`
} }
func (p PortForwarding) validate(vpnProvider string) (err error) { func (p PortForwarding) Validate(vpnProvider string) (err error) {
if !*p.Enabled { if !*p.Enabled {
return nil return nil
} }
@@ -59,7 +59,7 @@ func (p PortForwarding) validate(vpnProvider string) (err error) {
return nil return nil
} }
func (p *PortForwarding) copy() (copied PortForwarding) { func (p *PortForwarding) Copy() (copied PortForwarding) {
return PortForwarding{ return PortForwarding{
Enabled: gosettings.CopyPointer(p.Enabled), Enabled: gosettings.CopyPointer(p.Enabled),
Provider: gosettings.CopyPointer(p.Provider), Provider: gosettings.CopyPointer(p.Provider),
@@ -73,7 +73,7 @@ func (p *PortForwarding) mergeWith(other PortForwarding) {
p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath) p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath)
} }
func (p *PortForwarding) overrideWith(other PortForwarding) { func (p *PortForwarding) OverrideWith(other PortForwarding) {
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled) p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider) p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath) p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)

View File

@@ -49,7 +49,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
return fmt.Errorf("server selection: %w", err) return fmt.Errorf("server selection: %w", err)
} }
err = p.PortForwarding.validate(*p.Name) err = p.PortForwarding.Validate(*p.Name)
if err != nil { if err != nil {
return fmt.Errorf("port forwarding: %w", err) return fmt.Errorf("port forwarding: %w", err)
} }
@@ -61,7 +61,7 @@ func (p *Provider) copy() (copied Provider) {
return Provider{ return Provider{
Name: gosettings.CopyPointer(p.Name), Name: gosettings.CopyPointer(p.Name),
ServerSelection: p.ServerSelection.copy(), ServerSelection: p.ServerSelection.copy(),
PortForwarding: p.PortForwarding.copy(), PortForwarding: p.PortForwarding.Copy(),
} }
} }
@@ -74,7 +74,7 @@ func (p *Provider) mergeWith(other Provider) {
func (p *Provider) overrideWith(other Provider) { func (p *Provider) overrideWith(other Provider) {
p.Name = gosettings.OverrideWithPointer(p.Name, other.Name) p.Name = gosettings.OverrideWithPointer(p.Name, other.Name)
p.ServerSelection.overrideWith(other.ServerSelection) p.ServerSelection.overrideWith(other.ServerSelection)
p.PortForwarding.overrideWith(other.PortForwarding) p.PortForwarding.OverrideWith(other.PortForwarding)
} }
func (p *Provider) setDefaults() { func (p *Provider) setDefaults() {

View File

@@ -1,32 +0,0 @@
package portforward
import "context"
// firewallBlockPort obtains the state port thread safely and blocks
// it in the firewall if it is not the zero value (0).
func (l *Loop) firewallBlockPort(ctx context.Context) {
port := l.state.GetPortForwarded()
if port == 0 {
return
}
err := l.portAllower.RemoveAllowedPort(ctx, port)
if err != nil {
l.logger.Error("cannot block previous port in firewall: " + err.Error())
}
}
// firewallAllowPort obtains the state port thread safely and allows
// it in the firewall if it is not the zero value (0).
func (l *Loop) firewallAllowPort(ctx context.Context) {
port := l.state.GetPortForwarded()
if port == 0 {
return
}
startData := l.state.GetStartData()
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
if err != nil {
l.logger.Error("cannot allow port: " + err.Error())
}
}

View File

@@ -1,37 +0,0 @@
package portforward
import (
"fmt"
"os"
)
func (l *Loop) removePortForwardedFile() {
filepath := *l.state.GetSettings().Filepath
l.logger.Info("removing port file " + filepath)
if err := os.Remove(filepath); err != nil {
l.logger.Error(err.Error())
}
}
func (l *Loop) writePortForwardedFile(port uint16) {
filepath := *l.state.GetSettings().Filepath
l.logger.Info("writing port file " + filepath)
if err := writePortForwardedToFile(filepath, port, l.puid, l.pgid); err != nil {
l.logger.Error("writing port forwarded to file: " + err.Error())
}
}
func writePortForwardedToFile(filepath string, port uint16, uid, gid int) (err error) {
const perms = os.FileMode(0644)
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
if err != nil {
return fmt.Errorf("writing file: %w", err)
}
err = os.Chown(filepath, uid, gid)
if err != nil {
return fmt.Errorf("chowning file: %w", err)
}
return nil
}

View File

@@ -1,5 +0,0 @@
package portforward
func (l *Loop) GetPortForwarded() (port uint16) {
return l.state.GetPortForwarded()
}

View File

@@ -1,22 +0,0 @@
package portforward
import (
"context"
"time"
)
func (l *Loop) logAndWait(ctx context.Context, err error) {
if err != nil {
l.logger.Error(err.Error())
}
l.logger.Info("retrying in " + l.backoffTime.String())
timer := time.NewTimer(l.backoffTime)
l.backoffTime *= 2
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
}
}

View File

@@ -1,10 +1,20 @@
package portforward package portforward
import ( import "context"
"context"
) type Service interface {
Start(ctx context.Context) (runError <-chan error, err error)
Stop() (err error)
GetPortForwarded() (port uint16)
}
type PortAllower interface { type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
} }
type Logger interface {
Info(s string)
Warn(s string)
Error(s string)
}

View File

@@ -1,7 +0,0 @@
package portforward
type Logger interface {
Info(s string)
Warn(s string)
Error(s string)
}

View File

@@ -1,64 +1,139 @@
package portforward package portforward
import ( import (
"context"
"fmt"
"net/http" "net/http"
"sync" "sync"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/portforward/service"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/portforward/state"
) )
type Loop struct { type Loop struct {
statusManager *loopstate.State // State
state *state.State settings service.Settings
// Fixed parameters settingsMutex sync.RWMutex
puid int service Service
pgid int // Fixed injected objets
// Objects
client *http.Client client *http.Client
portAllower PortAllower portAllower PortAllower
logger Logger logger Logger
// Fixed parameters
uid, gid int
// Internal channels and locks // Internal channels and locks
start chan struct{} // runCtx is used to detect when the loop has exited
running chan models.LoopStatus // when performing an update
stop chan struct{} runCtx context.Context //nolint:containedctx
stopped chan struct{} runCancel context.CancelFunc
startMu sync.Mutex updatedSignal chan<- struct{}
backoffTime time.Duration runDone <-chan struct{}
userTrigger bool
} }
const defaultBackoffTime = 5 * time.Second
func NewLoop(settings settings.PortForwarding, func NewLoop(settings settings.PortForwarding,
client *http.Client, portAllower PortAllower, client *http.Client, portAllower PortAllower,
logger Logger, puid, pgid int) *Loop { logger Logger, uid, gid int) *Loop {
start := make(chan struct{})
running := make(chan models.LoopStatus)
stop := make(chan struct{})
stopped := make(chan struct{})
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, settings)
return &Loop{ return &Loop{
statusManager: statusManager, settings: service.Settings{
state: state, UserSettings: settings,
puid: puid, },
pgid: pgid,
// Objects
client: client, client: client,
portAllower: portAllower, portAllower: portAllower,
logger: logger, logger: logger,
start: start, uid: uid,
running: running, gid: gid,
stop: stop,
stopped: stopped,
userTrigger: true,
backoffTime: defaultBackoffTime,
} }
} }
func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) {
l.runCtx, l.runCancel = context.WithCancel(context.Background())
runDone := make(chan struct{})
l.runDone = runDone
updatedSignal := make(chan struct{})
l.updatedSignal = updatedSignal
runErrorCh := make(chan error)
go l.run(l.runCtx, runDone, runErrorCh, updatedSignal)
return runErrorCh, nil
}
func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
runErrorCh chan<- error, updatedSignal <-chan struct{}) {
defer close(runDone)
var serviceRunError <-chan error
for {
select {
case <-runCtx.Done():
// Stop call takes care of stopping the service
return
case <-updatedSignal: // first and subsequent start trigger
case err := <-serviceRunError:
runErrorCh <- err
return
}
firstRun := l.service == nil
if !firstRun {
err := l.service.Stop()
if err != nil {
runErrorCh <- fmt.Errorf("stopping previous service: %w", err)
return
}
}
l.settingsMutex.RLock()
l.service = service.New(l.settings, l.client,
l.portAllower, l.logger, l.uid, l.gid)
l.settingsMutex.RUnlock()
var err error
serviceRunError, err = l.service.Start(runCtx)
if err != nil {
if runCtx.Err() == nil { // crashed but NOT stopped
runErrorCh <- fmt.Errorf("starting new service: %w", err)
}
return
}
}
}
func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) {
l.settingsMutex.Lock()
l.settings, err = l.settings.UpdateWith(partialUpdate)
l.settingsMutex.Unlock()
if err != nil {
return err
}
select {
case l.updatedSignal <- struct{}{}:
// Settings are validated and if the service fails to start
// or crashes at runtime, the loop will stop and signal its
// parent goroutine. Settings validation should be the only
// error feedback for the caller of `Update`.
return nil
case <-l.runCtx.Done():
// loop has been stopped, no update can be done
return l.runCtx.Err()
}
}
func (l *Loop) Stop() (err error) {
l.runCancel()
<-l.runDone
if l.service != nil {
return l.service.Stop()
}
return nil
}
func (l *Loop) GetPortForwarded() (port uint16) {
if l.service == nil {
return 0
}
return l.service.GetPortForwarded()
}

View File

@@ -1,98 +0,0 @@
package portforward
import (
"context"
"strconv"
"github.com/qdm12/gluetun/internal/constants"
)
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
defer close(done)
select {
case <-l.start: // l.state.SetStartData called beforehand
case <-ctx.Done():
return
}
for ctx.Err() == nil {
pfCtx, pfCancel := context.WithCancel(ctx)
portCh := make(chan uint16)
errorCh := make(chan error)
startData := l.state.GetStartData()
go func(ctx context.Context, startData StartData) {
port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger,
startData.Gateway, startData.ServerName)
if err != nil {
errorCh <- err
return
}
portCh <- port
// Infinite loop
err = startData.PortForwarder.KeepPortForward(ctx, port,
startData.Gateway, startData.ServerName, l.logger)
errorCh <- err
}(pfCtx, startData)
if l.userTrigger {
l.userTrigger = false
l.running <- constants.Running
} else { // crash
l.backoffTime = defaultBackoffTime
l.statusManager.SetStatus(constants.Running)
}
stayHere := true
stopped := false
for stayHere {
select {
case <-ctx.Done():
pfCancel()
if stopped {
return
}
<-errorCh
close(errorCh)
close(portCh)
l.removePortForwardedFile()
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(0)
return
case <-l.start:
l.userTrigger = true
l.logger.Info("starting")
pfCancel()
stayHere = false
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
pfCancel()
<-errorCh
l.removePortForwardedFile()
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(0)
l.stopped <- struct{}{}
stopped = true
case port := <-portCh:
l.logger.Info("port forwarded is " + strconv.Itoa(int(port)))
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(port)
l.firewallAllowPort(ctx)
l.writePortForwardedFile(port)
case err := <-errorCh:
pfCancel()
close(errorCh)
close(portCh)
l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err)
stayHere = false
}
}
pfCancel() // for linting
}
}

View File

@@ -0,0 +1,23 @@
package service
import (
"fmt"
"os"
)
func (s *Service) writePortForwardedFile(port uint16) (err error) {
filepath := *s.settings.UserSettings.Filepath
s.logger.Info("writing port file " + filepath)
const perms = os.FileMode(0644)
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
if err != nil {
return fmt.Errorf("writing file: %w", err)
}
err = os.Chown(filepath, s.puid, s.pgid)
if err != nil {
return fmt.Errorf("chowning file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,16 @@
package service
import (
"context"
)
type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
}
type Logger interface {
Info(s string)
Warn(s string)
Error(s string)
}

View File

@@ -0,0 +1,45 @@
package service
import (
"context"
"net/http"
"sync"
)
type Service struct {
// State
portMutex sync.RWMutex
port uint16
// Fixed parameters
settings Settings
puid int
pgid int
// Fixed injected objets
client *http.Client
portAllower PortAllower
logger Logger
// Internal channels and locks
startStopMutex sync.Mutex
keepPortCancel context.CancelFunc
keepPortDoneCh <-chan struct{}
}
func New(settings Settings, client *http.Client,
portAllower PortAllower, logger Logger, puid, pgid int) *Service {
return &Service{
// Fixed parameters
settings: settings,
puid: puid,
pgid: pgid,
// Fixed injected objets
client: client,
portAllower: portAllower,
logger: logger,
}
}
func (s *Service) GetPortForwarded() (port uint16) {
s.portMutex.RLock()
defer s.portMutex.RUnlock()
return s.port
}

View File

@@ -0,0 +1,79 @@
package service
import (
"errors"
"fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gosettings"
)
type Settings struct {
UserSettings settings.PortForwarding
PortForwarder provider.PortForwarder
Gateway netip.Addr // needed for PIA and ProtonVPN
ServerName string // needed for PIA
Interface string // needed for PIA and ProtonVPN, tun0 for example
VPNProvider string // used to validate new settings
}
// UpdateWith deep copies the receiving settings, overrides the copy with
// fields set in the partialUpdate argument, validates the new settings
// and returns them if they are valid, or returns an error otherwise.
// In all cases, the receiving settings are unmodified.
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
updatedSettings = s.copy()
updatedSettings.overrideWith(partialUpdate)
err = updatedSettings.validate()
if err != nil {
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
}
return updatedSettings, nil
}
func (s Settings) copy() (copied Settings) {
copied.UserSettings = s.UserSettings.Copy()
copied.PortForwarder = s.PortForwarder
copied.Gateway = s.Gateway
copied.ServerName = s.ServerName
copied.Interface = s.Interface
copied.VPNProvider = s.VPNProvider
return copied
}
func (s *Settings) overrideWith(update Settings) {
s.UserSettings.OverrideWith(update.UserSettings)
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
s.Gateway = gosettings.OverrideWithValidator(s.Gateway, update.Gateway)
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
}
var (
ErrVPNProviderNotSet = errors.New("VPN provider not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrGatewayNotSet = errors.New("gateway not set")
ErrInterfaceNotSet = errors.New("interface not set")
)
func (s *Settings) validate() (err error) {
switch {
case s.VPNProvider == "":
return fmt.Errorf("%w", ErrVPNProviderNotSet)
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
case s.PortForwarder == nil:
return fmt.Errorf("%w", ErrPortForwarderNotSet)
case !s.Gateway.IsValid():
return fmt.Errorf("%w", ErrGatewayNotSet)
case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet)
}
return s.UserSettings.Validate(s.VPNProvider)
}

View File

@@ -0,0 +1,60 @@
package service
import (
"context"
"fmt"
)
func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
if !*s.settings.UserSettings.Enabled {
return nil, nil //nolint:nilnil
}
s.logger.Info("starting")
port, err := s.settings.PortForwarder.PortForward(ctx, s.client, s.logger,
s.settings.Gateway, s.settings.ServerName)
if err != nil {
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
}
s.logger.Info("port forwarded is " + fmt.Sprint(int(port)))
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
if err != nil {
return nil, fmt.Errorf("allowing port in firewall: %w", err)
}
err = s.writePortForwardedFile(port)
if err != nil {
_ = s.cleanup()
return nil, fmt.Errorf("writing port file: %w", err)
}
s.portMutex.Lock()
s.port = port
s.portMutex.Unlock()
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
s.keepPortCancel = keepPortCancel
runErrorCh := make(chan error)
keepPortDoneCh := make(chan struct{})
s.keepPortDoneCh = keepPortDoneCh
go func(ctx context.Context, settings Settings, port uint16,
runError chan<- error, doneCh chan<- struct{}) {
defer close(doneCh)
err = settings.PortForwarder.KeepPortForward(ctx, port,
settings.Gateway, settings.ServerName, s.logger)
crashed := ctx.Err() == nil
if !crashed { // stopped by Stop call
return
}
_ = s.cleanup()
runError <- err
}(keepPortCtx, s.settings, port, runErrorCh, keepPortDoneCh)
return runErrorCh, nil
}

View File

@@ -0,0 +1,47 @@
package service
import (
"context"
"fmt"
"os"
)
func (s *Service) Stop() (err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
s.portMutex.RLock()
serviceNotRunning := s.port == 0
s.portMutex.RUnlock()
if serviceNotRunning {
return nil
}
s.logger.Info("stopping")
s.keepPortCancel()
<-s.keepPortDoneCh
return s.cleanup()
}
func (s *Service) cleanup() (err error) {
s.portMutex.Lock()
defer s.portMutex.Unlock()
err = s.portAllower.RemoveAllowedPort(context.Background(), s.port)
if err != nil {
return fmt.Errorf("blocking previous port in firewall: %w", err)
}
s.port = 0
filepath := *s.settings.UserSettings.Filepath
s.logger.Info("removing port file " + filepath)
err = os.Remove(filepath)
if err != nil {
return fmt.Errorf("removing port file: %w", err)
}
return nil
}

View File

@@ -1,16 +0,0 @@
package portforward
import (
"context"
"github.com/qdm12/gluetun/internal/configuration/settings"
)
func (l *Loop) GetSettings() (settings settings.PortForwarding) {
return l.state.GetSettings()
}
func (l *Loop) SetSettings(ctx context.Context, settings settings.PortForwarding) (
outcome string) {
return l.state.SetSettings(ctx, settings)
}

View File

@@ -1,17 +0,0 @@
package state
// GetPortForwarded is used by the control HTTP server
// to obtain the port currently forwarded.
func (s *State) GetPortForwarded() (port uint16) {
s.portForwardedMu.RLock()
defer s.portForwardedMu.RUnlock()
return s.portForwarded
}
// SetPortForwarded is only used from within the OpenVPN loop
// to set the port forwarded.
func (s *State) SetPortForwarded(port uint16) {
s.portForwardedMu.Lock()
defer s.portForwardedMu.Unlock()
s.portForwarded = port
}

View File

@@ -1,49 +0,0 @@
package state
import (
"context"
"os"
"reflect"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
)
func (s *State) GetSettings() (settings settings.PortForwarding) {
s.settingsMu.RLock()
defer s.settingsMu.RUnlock()
return s.settings
}
func (s *State) SetSettings(ctx context.Context, settings settings.PortForwarding) (
outcome string) {
s.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(s.settings, settings)
if settingsUnchanged {
s.settingsMu.Unlock()
return "settings left unchanged"
}
if s.settings.Filepath != settings.Filepath {
_ = os.Rename(*s.settings.Filepath, *settings.Filepath)
}
newEnabled := *settings.Enabled
previousEnabled := *s.settings.Enabled
s.settings = settings
s.settingsMu.Unlock()
switch {
case !newEnabled && !previousEnabled:
case newEnabled && previousEnabled:
// no need to restart for now since we os.Rename the file here.
case newEnabled && !previousEnabled:
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
case !newEnabled && previousEnabled:
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
}
return "settings updated"
}

View File

@@ -1,26 +0,0 @@
package state
import (
"net/netip"
"github.com/qdm12/gluetun/internal/provider"
)
type StartData struct {
PortForwarder provider.PortForwarder
Gateway netip.Addr // needed for PIA
ServerName string // needed for PIA
Interface string // tun0 for example
}
func (s *State) GetStartData() (startData StartData) {
s.startDataMu.RLock()
defer s.startDataMu.RUnlock()
return s.startData
}
func (s *State) SetStartData(startData StartData) {
s.startDataMu.Lock()
defer s.startDataMu.Unlock()
s.startData = startData
}

View File

@@ -1,35 +0,0 @@
package state
import (
"context"
"sync"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
)
func New(statusApplier StatusApplier,
settings settings.PortForwarding) *State {
return &State{
statusApplier: statusApplier,
settings: settings,
}
}
type State struct {
statusApplier StatusApplier
settings settings.PortForwarding
settingsMu sync.RWMutex
portForwarded uint16
portForwardedMu sync.RWMutex
startData StartData
startDataMu sync.RWMutex
}
type StatusApplier interface {
ApplyStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
}

View File

@@ -1,27 +0,0 @@
package portforward
import (
"context"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/portforward/state"
)
func (l *Loop) GetStatus() (status models.LoopStatus) {
return l.statusManager.GetStatus()
}
type StartData = state.StartData
func (l *Loop) Start(ctx context.Context, data StartData) (
outcome string, err error) {
l.startMu.Lock()
defer l.startMu.Unlock()
l.state.SetStartData(data)
return l.statusManager.ApplyStatus(ctx, constants.Running)
}
func (l *Loop) Stop(ctx context.Context) (outcome string, err error) {
return l.statusManager.ApplyStatus(ctx, constants.Stopped)
}

View File

@@ -21,6 +21,7 @@ type Provider interface {
} }
type PortForwarder interface { type PortForwarder interface {
Name() string
PortForward(ctx context.Context, client *http.Client, PortForward(ctx context.Context, client *http.Client,
logger utils.Logger, gateway netip.Addr, serverName string) ( logger utils.Logger, gateway netip.Addr, serverName string) (
port uint16, err error) port uint16, err error)

View File

@@ -2,14 +2,14 @@ package vpn
import ( import (
"context" "context"
"time" "errors"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) { func (l *Loop) cleanup(vpnProvider string) {
for _, vpnPort := range l.vpnInputPorts { for _, vpnPort := range l.vpnInputPorts {
err := l.fw.RemoveAllowedPort(ctx, vpnPort) err := l.fw.RemoveAllowedPort(context.Background(), vpnPort)
if err != nil { if err != nil {
l.logger.Error("cannot remove allowed input port from firewall: " + err.Error()) l.logger.Error("cannot remove allowed input port from firewall: " + err.Error())
} }
@@ -17,11 +17,11 @@ func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
l.publicip.SetData(models.PublicIP{}) // clear public IP address data l.publicip.SetData(models.PublicIP{}) // clear public IP address data
if pfEnabled { err := l.stopPortForwarding(vpnProvider)
const pfTimeout = 100 * time.Millisecond if err != nil {
err := l.stopPortForwarding(ctx, pfTimeout) portForwardingAlreadyStopped := errors.Is(err, context.Canceled)
if err != nil { if !portForwardingAlreadyStopped {
l.logger.Error("cannot stop port forwarding: " + err.Error()) l.logger.Error("stopping port forwarding: " + err.Error())
} }
} }
} }

View File

@@ -8,6 +8,8 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
func ptrTo[T any](value T) *T { return &value }
// waitForError waits 100ms for an error in the waitError channel. // waitForError waits 100ms for an error in the waitError channel.
func (l *Loop) waitForError(ctx context.Context, func (l *Loop) waitForError(ctx context.Context,
waitError chan error) (err error) { waitError chan error) (err error) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/portforward" portforward "github.com/qdm12/gluetun/internal/portforward/service"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
) )
@@ -22,8 +22,7 @@ type Routing interface {
} }
type PortForward interface { type PortForward interface {
Start(ctx context.Context, data portforward.StartData) (outcome string, err error) UpdateWith(settings portforward.Settings) (err error)
Stop(ctx context.Context) (outcome string, err error)
} }
type OpenVPN interface { type OpenVPN interface {

View File

@@ -1,47 +1,35 @@
package vpn package vpn
import ( import (
"context"
"fmt" "fmt"
"time"
"github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/portforward/service"
) )
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) { func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
if !data.portForwarding {
return nil
}
// only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf) gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
if err != nil { if err != nil {
return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err) return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err)
} }
l.logger.Info("VPN gateway IP address: " + gateway.String()) l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{ partialUpdate := service.Settings{
PortForwarder: data.portForwarder, PortForwarder: data.portForwarder,
Gateway: gateway, Gateway: gateway,
ServerName: data.serverName,
Interface: data.vpnIntf, Interface: data.vpnIntf,
ServerName: data.serverName,
VPNProvider: data.portForwarder.Name(),
} }
_, err = l.portForward.Start(ctx, pfData) return l.portForward.UpdateWith(partialUpdate)
if err != nil {
return fmt.Errorf("starting port forwarding: %w", err)
}
return nil
} }
func (l *Loop) stopPortForwarding(ctx context.Context, func (l *Loop) stopPortForwarding(vpnProvider string) (err error) {
timeout time.Duration) (err error) { partialUpdate := service.Settings{
if timeout > 0 { VPNProvider: vpnProvider,
var cancel context.CancelFunc UserSettings: settings.PortForwarding{
ctx, cancel = context.WithTimeout(ctx, timeout) Enabled: ptrTo(false),
defer cancel() },
} }
return l.portForward.UpdateWith(partialUpdate)
_, err = l.portForward.Stop(ctx)
return err
} }

View File

@@ -22,10 +22,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
providerConf := l.providers.Get(*settings.Provider.Name) providerConf := l.providers.Get(*settings.Provider.Name)
portForwarding := *settings.Provider.PortForwarding.Enabled
customPortForwardingProvider := *settings.Provider.PortForwarding.Provider customPortForwardingProvider := *settings.Provider.PortForwarding.Provider
portForwader := providerConf portForwader := providerConf
if portForwarding && customPortForwardingProvider != "" { if customPortForwardingProvider != "" {
portForwader = l.providers.Get(customPortForwardingProvider) portForwader = l.providers.Get(customPortForwardingProvider)
} }
@@ -49,10 +48,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue continue
} }
tunnelUpData := tunnelUpData{ tunnelUpData := tunnelUpData{
portForwarding: portForwarding, serverName: serverName,
serverName: serverName, portForwarder: portForwader,
portForwarder: portForwader, vpnIntf: vpnInterface,
vpnIntf: vpnInterface,
} }
openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
@@ -76,7 +74,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-tunnelReady: case <-tunnelReady:
go l.onTunnelUp(openvpnCtx, tunnelUpData) go l.onTunnelUp(openvpnCtx, tunnelUpData)
case <-ctx.Done(): case <-ctx.Done():
l.cleanup(context.Background(), portForwarding) l.cleanup(portForwader.Name())
openvpnCancel() openvpnCancel()
<-waitError <-waitError
close(waitError) close(waitError)
@@ -84,7 +82,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-l.stop: case <-l.stop:
l.userTrigger = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
l.cleanup(context.Background(), portForwarding) l.cleanup(portForwader.Name())
openvpnCancel() openvpnCancel()
<-waitError <-waitError
// do not close waitError or the waitError // do not close waitError or the waitError
@@ -97,7 +95,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
l.statusManager.Lock() // prevent SetStatus from running in parallel l.statusManager.Lock() // prevent SetStatus from running in parallel
l.cleanup(context.Background(), portForwarding) l.cleanup(portForwader.Name())
openvpnCancel() openvpnCancel()
l.statusManager.SetStatus(constants.Crashed) l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)

View File

@@ -10,10 +10,9 @@ import (
type tunnelUpData struct { type tunnelUpData struct {
// Port forwarding // Port forwarding
portForwarding bool vpnIntf string
vpnIntf string serverName string
serverName string portForwarder provider.PortForwarder
portForwarder provider.PortForwarder
} }
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
@@ -42,7 +41,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
} }
} }
err := l.startPortForwarding(ctx, data) err := l.startPortForwarding(data)
if err != nil { if err != nil {
l.logger.Error(err.Error()) l.logger.Error(err.Error())
} }