fix(portforward): rework run loop and fix deadlocks (#1874)
This commit is contained in:
23
internal/portforward/service/fs.go
Normal file
23
internal/portforward/service/fs.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (s *Service) writePortForwardedFile(port uint16) (err error) {
|
||||
filepath := *s.settings.UserSettings.Filepath
|
||||
s.logger.Info("writing port file " + filepath)
|
||||
const perms = os.FileMode(0644)
|
||||
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
|
||||
err = os.Chown(filepath, s.puid, s.pgid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("chowning file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
16
internal/portforward/service/interfaces.go
Normal file
16
internal/portforward/service/interfaces.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type PortAllower interface {
|
||||
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
|
||||
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
45
internal/portforward/service/service.go
Normal file
45
internal/portforward/service/service.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
// State
|
||||
portMutex sync.RWMutex
|
||||
port uint16
|
||||
// Fixed parameters
|
||||
settings Settings
|
||||
puid int
|
||||
pgid int
|
||||
// Fixed injected objets
|
||||
client *http.Client
|
||||
portAllower PortAllower
|
||||
logger Logger
|
||||
// Internal channels and locks
|
||||
startStopMutex sync.Mutex
|
||||
keepPortCancel context.CancelFunc
|
||||
keepPortDoneCh <-chan struct{}
|
||||
}
|
||||
|
||||
func New(settings Settings, client *http.Client,
|
||||
portAllower PortAllower, logger Logger, puid, pgid int) *Service {
|
||||
return &Service{
|
||||
// Fixed parameters
|
||||
settings: settings,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
// Fixed injected objets
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) GetPortForwarded() (port uint16) {
|
||||
s.portMutex.RLock()
|
||||
defer s.portMutex.RUnlock()
|
||||
return s.port
|
||||
}
|
||||
79
internal/portforward/service/settings.go
Normal file
79
internal/portforward/service/settings.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gosettings"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
UserSettings settings.PortForwarding
|
||||
PortForwarder provider.PortForwarder
|
||||
Gateway netip.Addr // needed for PIA and ProtonVPN
|
||||
ServerName string // needed for PIA
|
||||
Interface string // needed for PIA and ProtonVPN, tun0 for example
|
||||
VPNProvider string // used to validate new settings
|
||||
}
|
||||
|
||||
// UpdateWith deep copies the receiving settings, overrides the copy with
|
||||
// fields set in the partialUpdate argument, validates the new settings
|
||||
// and returns them if they are valid, or returns an error otherwise.
|
||||
// In all cases, the receiving settings are unmodified.
|
||||
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
|
||||
updatedSettings = s.copy()
|
||||
updatedSettings.overrideWith(partialUpdate)
|
||||
err = updatedSettings.validate()
|
||||
if err != nil {
|
||||
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
|
||||
}
|
||||
return updatedSettings, nil
|
||||
}
|
||||
|
||||
func (s Settings) copy() (copied Settings) {
|
||||
copied.UserSettings = s.UserSettings.Copy()
|
||||
copied.PortForwarder = s.PortForwarder
|
||||
copied.Gateway = s.Gateway
|
||||
copied.ServerName = s.ServerName
|
||||
copied.Interface = s.Interface
|
||||
copied.VPNProvider = s.VPNProvider
|
||||
return copied
|
||||
}
|
||||
|
||||
func (s *Settings) overrideWith(update Settings) {
|
||||
s.UserSettings.OverrideWith(update.UserSettings)
|
||||
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
|
||||
s.Gateway = gosettings.OverrideWithValidator(s.Gateway, update.Gateway)
|
||||
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
|
||||
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
|
||||
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVPNProviderNotSet = errors.New("VPN provider not set")
|
||||
ErrServerNameNotSet = errors.New("server name not set")
|
||||
ErrPortForwarderNotSet = errors.New("port forwarder not set")
|
||||
ErrGatewayNotSet = errors.New("gateway not set")
|
||||
ErrInterfaceNotSet = errors.New("interface not set")
|
||||
)
|
||||
|
||||
func (s *Settings) validate() (err error) {
|
||||
switch {
|
||||
case s.VPNProvider == "":
|
||||
return fmt.Errorf("%w", ErrVPNProviderNotSet)
|
||||
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "":
|
||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
||||
case s.PortForwarder == nil:
|
||||
return fmt.Errorf("%w", ErrPortForwarderNotSet)
|
||||
case !s.Gateway.IsValid():
|
||||
return fmt.Errorf("%w", ErrGatewayNotSet)
|
||||
case s.Interface == "":
|
||||
return fmt.Errorf("%w", ErrInterfaceNotSet)
|
||||
}
|
||||
|
||||
return s.UserSettings.Validate(s.VPNProvider)
|
||||
}
|
||||
60
internal/portforward/service/start.go
Normal file
60
internal/portforward/service/start.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) {
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
|
||||
if !*s.settings.UserSettings.Enabled {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
s.logger.Info("starting")
|
||||
port, err := s.settings.PortForwarder.PortForward(ctx, s.client, s.logger,
|
||||
s.settings.Gateway, s.settings.ServerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("port forwarded is " + fmt.Sprint(int(port)))
|
||||
|
||||
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allowing port in firewall: %w", err)
|
||||
}
|
||||
|
||||
err = s.writePortForwardedFile(port)
|
||||
if err != nil {
|
||||
_ = s.cleanup()
|
||||
return nil, fmt.Errorf("writing port file: %w", err)
|
||||
}
|
||||
|
||||
s.portMutex.Lock()
|
||||
s.port = port
|
||||
s.portMutex.Unlock()
|
||||
|
||||
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
||||
s.keepPortCancel = keepPortCancel
|
||||
runErrorCh := make(chan error)
|
||||
keepPortDoneCh := make(chan struct{})
|
||||
s.keepPortDoneCh = keepPortDoneCh
|
||||
|
||||
go func(ctx context.Context, settings Settings, port uint16,
|
||||
runError chan<- error, doneCh chan<- struct{}) {
|
||||
defer close(doneCh)
|
||||
err = settings.PortForwarder.KeepPortForward(ctx, port,
|
||||
settings.Gateway, settings.ServerName, s.logger)
|
||||
crashed := ctx.Err() == nil
|
||||
if !crashed { // stopped by Stop call
|
||||
return
|
||||
}
|
||||
_ = s.cleanup()
|
||||
runError <- err
|
||||
}(keepPortCtx, s.settings, port, runErrorCh, keepPortDoneCh)
|
||||
|
||||
return runErrorCh, nil
|
||||
}
|
||||
47
internal/portforward/service/stop.go
Normal file
47
internal/portforward/service/stop.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (s *Service) Stop() (err error) {
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
|
||||
s.portMutex.RLock()
|
||||
serviceNotRunning := s.port == 0
|
||||
s.portMutex.RUnlock()
|
||||
if serviceNotRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info("stopping")
|
||||
|
||||
s.keepPortCancel()
|
||||
<-s.keepPortDoneCh
|
||||
|
||||
return s.cleanup()
|
||||
}
|
||||
|
||||
func (s *Service) cleanup() (err error) {
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
err = s.portAllower.RemoveAllowedPort(context.Background(), s.port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocking previous port in firewall: %w", err)
|
||||
}
|
||||
|
||||
s.port = 0
|
||||
|
||||
filepath := *s.settings.UserSettings.Filepath
|
||||
s.logger.Info("removing port file " + filepath)
|
||||
err = os.Remove(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing port file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user