fix(portforward): rework run loop and fix deadlocks (#1874)

This commit is contained in:
Quentin McGaw
2023-09-23 12:57:12 +02:00
committed by GitHub
parent c435bbb32c
commit 71201411f4
30 changed files with 453 additions and 476 deletions

View File

@@ -0,0 +1,23 @@
package service
import (
"fmt"
"os"
)
func (s *Service) writePortForwardedFile(port uint16) (err error) {
filepath := *s.settings.UserSettings.Filepath
s.logger.Info("writing port file " + filepath)
const perms = os.FileMode(0644)
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
if err != nil {
return fmt.Errorf("writing file: %w", err)
}
err = os.Chown(filepath, s.puid, s.pgid)
if err != nil {
return fmt.Errorf("chowning file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,16 @@
package service
import (
"context"
)
type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
}
type Logger interface {
Info(s string)
Warn(s string)
Error(s string)
}

View File

@@ -0,0 +1,45 @@
package service
import (
"context"
"net/http"
"sync"
)
type Service struct {
// State
portMutex sync.RWMutex
port uint16
// Fixed parameters
settings Settings
puid int
pgid int
// Fixed injected objets
client *http.Client
portAllower PortAllower
logger Logger
// Internal channels and locks
startStopMutex sync.Mutex
keepPortCancel context.CancelFunc
keepPortDoneCh <-chan struct{}
}
func New(settings Settings, client *http.Client,
portAllower PortAllower, logger Logger, puid, pgid int) *Service {
return &Service{
// Fixed parameters
settings: settings,
puid: puid,
pgid: pgid,
// Fixed injected objets
client: client,
portAllower: portAllower,
logger: logger,
}
}
func (s *Service) GetPortForwarded() (port uint16) {
s.portMutex.RLock()
defer s.portMutex.RUnlock()
return s.port
}

View File

@@ -0,0 +1,79 @@
package service
import (
"errors"
"fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gosettings"
)
type Settings struct {
UserSettings settings.PortForwarding
PortForwarder provider.PortForwarder
Gateway netip.Addr // needed for PIA and ProtonVPN
ServerName string // needed for PIA
Interface string // needed for PIA and ProtonVPN, tun0 for example
VPNProvider string // used to validate new settings
}
// UpdateWith deep copies the receiving settings, overrides the copy with
// fields set in the partialUpdate argument, validates the new settings
// and returns them if they are valid, or returns an error otherwise.
// In all cases, the receiving settings are unmodified.
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
updatedSettings = s.copy()
updatedSettings.overrideWith(partialUpdate)
err = updatedSettings.validate()
if err != nil {
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
}
return updatedSettings, nil
}
func (s Settings) copy() (copied Settings) {
copied.UserSettings = s.UserSettings.Copy()
copied.PortForwarder = s.PortForwarder
copied.Gateway = s.Gateway
copied.ServerName = s.ServerName
copied.Interface = s.Interface
copied.VPNProvider = s.VPNProvider
return copied
}
func (s *Settings) overrideWith(update Settings) {
s.UserSettings.OverrideWith(update.UserSettings)
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
s.Gateway = gosettings.OverrideWithValidator(s.Gateway, update.Gateway)
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
}
var (
ErrVPNProviderNotSet = errors.New("VPN provider not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrGatewayNotSet = errors.New("gateway not set")
ErrInterfaceNotSet = errors.New("interface not set")
)
func (s *Settings) validate() (err error) {
switch {
case s.VPNProvider == "":
return fmt.Errorf("%w", ErrVPNProviderNotSet)
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
case s.PortForwarder == nil:
return fmt.Errorf("%w", ErrPortForwarderNotSet)
case !s.Gateway.IsValid():
return fmt.Errorf("%w", ErrGatewayNotSet)
case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet)
}
return s.UserSettings.Validate(s.VPNProvider)
}

View File

@@ -0,0 +1,60 @@
package service
import (
"context"
"fmt"
)
func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
if !*s.settings.UserSettings.Enabled {
return nil, nil //nolint:nilnil
}
s.logger.Info("starting")
port, err := s.settings.PortForwarder.PortForward(ctx, s.client, s.logger,
s.settings.Gateway, s.settings.ServerName)
if err != nil {
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
}
s.logger.Info("port forwarded is " + fmt.Sprint(int(port)))
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
if err != nil {
return nil, fmt.Errorf("allowing port in firewall: %w", err)
}
err = s.writePortForwardedFile(port)
if err != nil {
_ = s.cleanup()
return nil, fmt.Errorf("writing port file: %w", err)
}
s.portMutex.Lock()
s.port = port
s.portMutex.Unlock()
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
s.keepPortCancel = keepPortCancel
runErrorCh := make(chan error)
keepPortDoneCh := make(chan struct{})
s.keepPortDoneCh = keepPortDoneCh
go func(ctx context.Context, settings Settings, port uint16,
runError chan<- error, doneCh chan<- struct{}) {
defer close(doneCh)
err = settings.PortForwarder.KeepPortForward(ctx, port,
settings.Gateway, settings.ServerName, s.logger)
crashed := ctx.Err() == nil
if !crashed { // stopped by Stop call
return
}
_ = s.cleanup()
runError <- err
}(keepPortCtx, s.settings, port, runErrorCh, keepPortDoneCh)
return runErrorCh, nil
}

View File

@@ -0,0 +1,47 @@
package service
import (
"context"
"fmt"
"os"
)
func (s *Service) Stop() (err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
s.portMutex.RLock()
serviceNotRunning := s.port == 0
s.portMutex.RUnlock()
if serviceNotRunning {
return nil
}
s.logger.Info("stopping")
s.keepPortCancel()
<-s.keepPortDoneCh
return s.cleanup()
}
func (s *Service) cleanup() (err error) {
s.portMutex.Lock()
defer s.portMutex.Unlock()
err = s.portAllower.RemoveAllowedPort(context.Background(), s.port)
if err != nil {
return fmt.Errorf("blocking previous port in firewall: %w", err)
}
s.port = 0
filepath := *s.settings.UserSettings.Filepath
s.logger.Info("removing port file " + filepath)
err = os.Remove(filepath)
if err != nil {
return fmt.Errorf("removing port file: %w", err)
}
return nil
}