fix(portforward): rework run loop and fix deadlocks (#1874)
This commit is contained in:
@@ -377,9 +377,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
portForwardLogger := logger.New(log.SetComponent("port forwarding"))
|
||||
portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding,
|
||||
httpClient, firewallConf, portForwardLogger, puid, pgid)
|
||||
portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler(
|
||||
"port forwarding", goroutine.OptionTimeout(time.Second))
|
||||
go portForwardLooper.Run(portForwardCtx, portForwardDone)
|
||||
portForwardRunError, _ := portForwardLooper.Start(context.Background())
|
||||
|
||||
unboundLogger := logger.New(log.SetComponent("dns"))
|
||||
unboundLooper := dns.NewLoop(dnsConf, allSettings.DNS, httpClient,
|
||||
@@ -481,13 +479,21 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
order.OptionOnSuccess(defaultShutdownOnSuccess),
|
||||
order.OptionOnFailure(defaultShutdownOnFailure))
|
||||
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
|
||||
vpnHandler, portForwardHandler, otherGroupHandler)
|
||||
vpnHandler, otherGroupHandler)
|
||||
|
||||
// Start VPN for the first time in a blocking call
|
||||
// until the VPN is launched
|
||||
_, _ = vpnLooper.ApplyStatus(ctx, constants.Running) // TODO option to disable with variable
|
||||
|
||||
<-ctx.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = portForwardLooper.Stop()
|
||||
if err != nil {
|
||||
logger.Error("stopping port forward loop: " + err.Error())
|
||||
}
|
||||
case err := <-portForwardRunError:
|
||||
logger.Errorf("port forwarding loop crashed: %s", err)
|
||||
}
|
||||
|
||||
return orderHandler.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ type PortForwarding struct {
|
||||
Filepath *string `json:"status_file_path"`
|
||||
}
|
||||
|
||||
func (p PortForwarding) validate(vpnProvider string) (err error) {
|
||||
func (p PortForwarding) Validate(vpnProvider string) (err error) {
|
||||
if !*p.Enabled {
|
||||
return nil
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (p PortForwarding) validate(vpnProvider string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PortForwarding) copy() (copied PortForwarding) {
|
||||
func (p *PortForwarding) Copy() (copied PortForwarding) {
|
||||
return PortForwarding{
|
||||
Enabled: gosettings.CopyPointer(p.Enabled),
|
||||
Provider: gosettings.CopyPointer(p.Provider),
|
||||
@@ -73,7 +73,7 @@ func (p *PortForwarding) mergeWith(other PortForwarding) {
|
||||
p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath)
|
||||
}
|
||||
|
||||
func (p *PortForwarding) overrideWith(other PortForwarding) {
|
||||
func (p *PortForwarding) OverrideWith(other PortForwarding) {
|
||||
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
|
||||
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
|
||||
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
|
||||
|
||||
@@ -49,7 +49,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
|
||||
return fmt.Errorf("server selection: %w", err)
|
||||
}
|
||||
|
||||
err = p.PortForwarding.validate(*p.Name)
|
||||
err = p.PortForwarding.Validate(*p.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("port forwarding: %w", err)
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (p *Provider) copy() (copied Provider) {
|
||||
return Provider{
|
||||
Name: gosettings.CopyPointer(p.Name),
|
||||
ServerSelection: p.ServerSelection.copy(),
|
||||
PortForwarding: p.PortForwarding.copy(),
|
||||
PortForwarding: p.PortForwarding.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (p *Provider) mergeWith(other Provider) {
|
||||
func (p *Provider) overrideWith(other Provider) {
|
||||
p.Name = gosettings.OverrideWithPointer(p.Name, other.Name)
|
||||
p.ServerSelection.overrideWith(other.ServerSelection)
|
||||
p.PortForwarding.overrideWith(other.PortForwarding)
|
||||
p.PortForwarding.OverrideWith(other.PortForwarding)
|
||||
}
|
||||
|
||||
func (p *Provider) setDefaults() {
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package portforward
|
||||
|
||||
import "context"
|
||||
|
||||
// firewallBlockPort obtains the state port thread safely and blocks
|
||||
// it in the firewall if it is not the zero value (0).
|
||||
func (l *Loop) firewallBlockPort(ctx context.Context) {
|
||||
port := l.state.GetPortForwarded()
|
||||
if port == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := l.portAllower.RemoveAllowedPort(ctx, port)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot block previous port in firewall: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// firewallAllowPort obtains the state port thread safely and allows
|
||||
// it in the firewall if it is not the zero value (0).
|
||||
func (l *Loop) firewallAllowPort(ctx context.Context) {
|
||||
port := l.state.GetPortForwarded()
|
||||
if port == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
startData := l.state.GetStartData()
|
||||
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot allow port: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (l *Loop) removePortForwardedFile() {
|
||||
filepath := *l.state.GetSettings().Filepath
|
||||
l.logger.Info("removing port file " + filepath)
|
||||
if err := os.Remove(filepath); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) writePortForwardedFile(port uint16) {
|
||||
filepath := *l.state.GetSettings().Filepath
|
||||
l.logger.Info("writing port file " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, port, l.puid, l.pgid); err != nil {
|
||||
l.logger.Error("writing port forwarded to file: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(filepath string, port uint16, uid, gid int) (err error) {
|
||||
const perms = os.FileMode(0644)
|
||||
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
|
||||
err = os.Chown(filepath, uid, gid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("chowning file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package portforward
|
||||
|
||||
func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
return l.state.GetPortForwarded()
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (l *Loop) logAndWait(ctx context.Context, err error) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
l.logger.Info("retrying in " + l.backoffTime.String())
|
||||
timer := time.NewTimer(l.backoffTime)
|
||||
l.backoffTime *= 2
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,20 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
import "context"
|
||||
|
||||
type Service interface {
|
||||
Start(ctx context.Context) (runError <-chan error, err error)
|
||||
Stop() (err error)
|
||||
GetPortForwarded() (port uint16)
|
||||
}
|
||||
|
||||
type PortAllower interface {
|
||||
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
|
||||
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package portforward
|
||||
|
||||
type Logger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -1,64 +1,139 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
"github.com/qdm12/gluetun/internal/portforward/service"
|
||||
)
|
||||
|
||||
type Loop struct {
|
||||
statusManager *loopstate.State
|
||||
state *state.State
|
||||
// Fixed parameters
|
||||
puid int
|
||||
pgid int
|
||||
// Objects
|
||||
// State
|
||||
settings service.Settings
|
||||
settingsMutex sync.RWMutex
|
||||
service Service
|
||||
// Fixed injected objets
|
||||
client *http.Client
|
||||
portAllower PortAllower
|
||||
logger Logger
|
||||
// Fixed parameters
|
||||
uid, gid int
|
||||
// Internal channels and locks
|
||||
start chan struct{}
|
||||
running chan models.LoopStatus
|
||||
stop chan struct{}
|
||||
stopped chan struct{}
|
||||
startMu sync.Mutex
|
||||
backoffTime time.Duration
|
||||
userTrigger bool
|
||||
// runCtx is used to detect when the loop has exited
|
||||
// when performing an update
|
||||
runCtx context.Context //nolint:containedctx
|
||||
runCancel context.CancelFunc
|
||||
updatedSignal chan<- struct{}
|
||||
runDone <-chan struct{}
|
||||
}
|
||||
|
||||
const defaultBackoffTime = 5 * time.Second
|
||||
|
||||
func NewLoop(settings settings.PortForwarding,
|
||||
client *http.Client, portAllower PortAllower,
|
||||
logger Logger, puid, pgid int) *Loop {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
stop := make(chan struct{})
|
||||
stopped := make(chan struct{})
|
||||
|
||||
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
|
||||
state := state.New(statusManager, settings)
|
||||
|
||||
logger Logger, uid, gid int) *Loop {
|
||||
return &Loop{
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
// Objects
|
||||
settings: service.Settings{
|
||||
UserSettings: settings,
|
||||
},
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
userTrigger: true,
|
||||
backoffTime: defaultBackoffTime,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) {
|
||||
l.runCtx, l.runCancel = context.WithCancel(context.Background())
|
||||
runDone := make(chan struct{})
|
||||
l.runDone = runDone
|
||||
|
||||
updatedSignal := make(chan struct{})
|
||||
l.updatedSignal = updatedSignal
|
||||
runErrorCh := make(chan error)
|
||||
|
||||
go l.run(l.runCtx, runDone, runErrorCh, updatedSignal)
|
||||
|
||||
return runErrorCh, nil
|
||||
}
|
||||
|
||||
func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
||||
runErrorCh chan<- error, updatedSignal <-chan struct{}) {
|
||||
defer close(runDone)
|
||||
|
||||
var serviceRunError <-chan error
|
||||
for {
|
||||
select {
|
||||
case <-runCtx.Done():
|
||||
// Stop call takes care of stopping the service
|
||||
return
|
||||
case <-updatedSignal: // first and subsequent start trigger
|
||||
case err := <-serviceRunError:
|
||||
runErrorCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
firstRun := l.service == nil
|
||||
if !firstRun {
|
||||
err := l.service.Stop()
|
||||
if err != nil {
|
||||
runErrorCh <- fmt.Errorf("stopping previous service: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
l.settingsMutex.RLock()
|
||||
l.service = service.New(l.settings, l.client,
|
||||
l.portAllower, l.logger, l.uid, l.gid)
|
||||
l.settingsMutex.RUnlock()
|
||||
|
||||
var err error
|
||||
serviceRunError, err = l.service.Start(runCtx)
|
||||
if err != nil {
|
||||
if runCtx.Err() == nil { // crashed but NOT stopped
|
||||
runErrorCh <- fmt.Errorf("starting new service: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) {
|
||||
l.settingsMutex.Lock()
|
||||
l.settings, err = l.settings.UpdateWith(partialUpdate)
|
||||
l.settingsMutex.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case l.updatedSignal <- struct{}{}:
|
||||
// Settings are validated and if the service fails to start
|
||||
// or crashes at runtime, the loop will stop and signal its
|
||||
// parent goroutine. Settings validation should be the only
|
||||
// error feedback for the caller of `Update`.
|
||||
return nil
|
||||
case <-l.runCtx.Done():
|
||||
// loop has been stopped, no update can be done
|
||||
return l.runCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) Stop() (err error) {
|
||||
l.runCancel()
|
||||
<-l.runDone
|
||||
|
||||
if l.service != nil {
|
||||
return l.service.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
if l.service == nil {
|
||||
return 0
|
||||
}
|
||||
return l.service.GetPortForwarded()
|
||||
}
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
|
||||
select {
|
||||
case <-l.start: // l.state.SetStartData called beforehand
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
pfCtx, pfCancel := context.WithCancel(ctx)
|
||||
|
||||
portCh := make(chan uint16)
|
||||
errorCh := make(chan error)
|
||||
|
||||
startData := l.state.GetStartData()
|
||||
|
||||
go func(ctx context.Context, startData StartData) {
|
||||
port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger,
|
||||
startData.Gateway, startData.ServerName)
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
portCh <- port
|
||||
|
||||
// Infinite loop
|
||||
err = startData.PortForwarder.KeepPortForward(ctx, port,
|
||||
startData.Gateway, startData.ServerName, l.logger)
|
||||
errorCh <- err
|
||||
}(pfCtx, startData)
|
||||
|
||||
if l.userTrigger {
|
||||
l.userTrigger = false
|
||||
l.running <- constants.Running
|
||||
} else { // crash
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.statusManager.SetStatus(constants.Running)
|
||||
}
|
||||
|
||||
stayHere := true
|
||||
stopped := false
|
||||
for stayHere {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pfCancel()
|
||||
if stopped {
|
||||
return
|
||||
}
|
||||
<-errorCh
|
||||
close(errorCh)
|
||||
close(portCh)
|
||||
l.removePortForwardedFile()
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(0)
|
||||
return
|
||||
case <-l.start:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("starting")
|
||||
pfCancel()
|
||||
stayHere = false
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
pfCancel()
|
||||
<-errorCh
|
||||
l.removePortForwardedFile()
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(0)
|
||||
l.stopped <- struct{}{}
|
||||
stopped = true
|
||||
case port := <-portCh:
|
||||
l.logger.Info("port forwarded is " + strconv.Itoa(int(port)))
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(port)
|
||||
l.firewallAllowPort(ctx)
|
||||
l.writePortForwardedFile(port)
|
||||
case err := <-errorCh:
|
||||
pfCancel()
|
||||
close(errorCh)
|
||||
close(portCh)
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
stayHere = false
|
||||
}
|
||||
}
|
||||
pfCancel() // for linting
|
||||
}
|
||||
}
|
||||
23
internal/portforward/service/fs.go
Normal file
23
internal/portforward/service/fs.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (s *Service) writePortForwardedFile(port uint16) (err error) {
|
||||
filepath := *s.settings.UserSettings.Filepath
|
||||
s.logger.Info("writing port file " + filepath)
|
||||
const perms = os.FileMode(0644)
|
||||
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
|
||||
err = os.Chown(filepath, s.puid, s.pgid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("chowning file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
16
internal/portforward/service/interfaces.go
Normal file
16
internal/portforward/service/interfaces.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type PortAllower interface {
|
||||
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
|
||||
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Info(s string)
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
45
internal/portforward/service/service.go
Normal file
45
internal/portforward/service/service.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
// State
|
||||
portMutex sync.RWMutex
|
||||
port uint16
|
||||
// Fixed parameters
|
||||
settings Settings
|
||||
puid int
|
||||
pgid int
|
||||
// Fixed injected objets
|
||||
client *http.Client
|
||||
portAllower PortAllower
|
||||
logger Logger
|
||||
// Internal channels and locks
|
||||
startStopMutex sync.Mutex
|
||||
keepPortCancel context.CancelFunc
|
||||
keepPortDoneCh <-chan struct{}
|
||||
}
|
||||
|
||||
func New(settings Settings, client *http.Client,
|
||||
portAllower PortAllower, logger Logger, puid, pgid int) *Service {
|
||||
return &Service{
|
||||
// Fixed parameters
|
||||
settings: settings,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
// Fixed injected objets
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) GetPortForwarded() (port uint16) {
|
||||
s.portMutex.RLock()
|
||||
defer s.portMutex.RUnlock()
|
||||
return s.port
|
||||
}
|
||||
79
internal/portforward/service/settings.go
Normal file
79
internal/portforward/service/settings.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gosettings"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
UserSettings settings.PortForwarding
|
||||
PortForwarder provider.PortForwarder
|
||||
Gateway netip.Addr // needed for PIA and ProtonVPN
|
||||
ServerName string // needed for PIA
|
||||
Interface string // needed for PIA and ProtonVPN, tun0 for example
|
||||
VPNProvider string // used to validate new settings
|
||||
}
|
||||
|
||||
// UpdateWith deep copies the receiving settings, overrides the copy with
|
||||
// fields set in the partialUpdate argument, validates the new settings
|
||||
// and returns them if they are valid, or returns an error otherwise.
|
||||
// In all cases, the receiving settings are unmodified.
|
||||
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
|
||||
updatedSettings = s.copy()
|
||||
updatedSettings.overrideWith(partialUpdate)
|
||||
err = updatedSettings.validate()
|
||||
if err != nil {
|
||||
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
|
||||
}
|
||||
return updatedSettings, nil
|
||||
}
|
||||
|
||||
func (s Settings) copy() (copied Settings) {
|
||||
copied.UserSettings = s.UserSettings.Copy()
|
||||
copied.PortForwarder = s.PortForwarder
|
||||
copied.Gateway = s.Gateway
|
||||
copied.ServerName = s.ServerName
|
||||
copied.Interface = s.Interface
|
||||
copied.VPNProvider = s.VPNProvider
|
||||
return copied
|
||||
}
|
||||
|
||||
func (s *Settings) overrideWith(update Settings) {
|
||||
s.UserSettings.OverrideWith(update.UserSettings)
|
||||
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
|
||||
s.Gateway = gosettings.OverrideWithValidator(s.Gateway, update.Gateway)
|
||||
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
|
||||
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
|
||||
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVPNProviderNotSet = errors.New("VPN provider not set")
|
||||
ErrServerNameNotSet = errors.New("server name not set")
|
||||
ErrPortForwarderNotSet = errors.New("port forwarder not set")
|
||||
ErrGatewayNotSet = errors.New("gateway not set")
|
||||
ErrInterfaceNotSet = errors.New("interface not set")
|
||||
)
|
||||
|
||||
func (s *Settings) validate() (err error) {
|
||||
switch {
|
||||
case s.VPNProvider == "":
|
||||
return fmt.Errorf("%w", ErrVPNProviderNotSet)
|
||||
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "":
|
||||
return fmt.Errorf("%w", ErrServerNameNotSet)
|
||||
case s.PortForwarder == nil:
|
||||
return fmt.Errorf("%w", ErrPortForwarderNotSet)
|
||||
case !s.Gateway.IsValid():
|
||||
return fmt.Errorf("%w", ErrGatewayNotSet)
|
||||
case s.Interface == "":
|
||||
return fmt.Errorf("%w", ErrInterfaceNotSet)
|
||||
}
|
||||
|
||||
return s.UserSettings.Validate(s.VPNProvider)
|
||||
}
|
||||
60
internal/portforward/service/start.go
Normal file
60
internal/portforward/service/start.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) {
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
|
||||
if !*s.settings.UserSettings.Enabled {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
s.logger.Info("starting")
|
||||
port, err := s.settings.PortForwarder.PortForward(ctx, s.client, s.logger,
|
||||
s.settings.Gateway, s.settings.ServerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("port forwarded is " + fmt.Sprint(int(port)))
|
||||
|
||||
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allowing port in firewall: %w", err)
|
||||
}
|
||||
|
||||
err = s.writePortForwardedFile(port)
|
||||
if err != nil {
|
||||
_ = s.cleanup()
|
||||
return nil, fmt.Errorf("writing port file: %w", err)
|
||||
}
|
||||
|
||||
s.portMutex.Lock()
|
||||
s.port = port
|
||||
s.portMutex.Unlock()
|
||||
|
||||
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
||||
s.keepPortCancel = keepPortCancel
|
||||
runErrorCh := make(chan error)
|
||||
keepPortDoneCh := make(chan struct{})
|
||||
s.keepPortDoneCh = keepPortDoneCh
|
||||
|
||||
go func(ctx context.Context, settings Settings, port uint16,
|
||||
runError chan<- error, doneCh chan<- struct{}) {
|
||||
defer close(doneCh)
|
||||
err = settings.PortForwarder.KeepPortForward(ctx, port,
|
||||
settings.Gateway, settings.ServerName, s.logger)
|
||||
crashed := ctx.Err() == nil
|
||||
if !crashed { // stopped by Stop call
|
||||
return
|
||||
}
|
||||
_ = s.cleanup()
|
||||
runError <- err
|
||||
}(keepPortCtx, s.settings, port, runErrorCh, keepPortDoneCh)
|
||||
|
||||
return runErrorCh, nil
|
||||
}
|
||||
47
internal/portforward/service/stop.go
Normal file
47
internal/portforward/service/stop.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (s *Service) Stop() (err error) {
|
||||
s.startStopMutex.Lock()
|
||||
defer s.startStopMutex.Unlock()
|
||||
|
||||
s.portMutex.RLock()
|
||||
serviceNotRunning := s.port == 0
|
||||
s.portMutex.RUnlock()
|
||||
if serviceNotRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info("stopping")
|
||||
|
||||
s.keepPortCancel()
|
||||
<-s.keepPortDoneCh
|
||||
|
||||
return s.cleanup()
|
||||
}
|
||||
|
||||
func (s *Service) cleanup() (err error) {
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
err = s.portAllower.RemoveAllowedPort(context.Background(), s.port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocking previous port in firewall: %w", err)
|
||||
}
|
||||
|
||||
s.port = 0
|
||||
|
||||
filepath := *s.settings.UserSettings.Filepath
|
||||
s.logger.Info("removing port file " + filepath)
|
||||
err = os.Remove(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing port file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
)
|
||||
|
||||
func (l *Loop) GetSettings() (settings settings.PortForwarding) {
|
||||
return l.state.GetSettings()
|
||||
}
|
||||
|
||||
func (l *Loop) SetSettings(ctx context.Context, settings settings.PortForwarding) (
|
||||
outcome string) {
|
||||
return l.state.SetSettings(ctx, settings)
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package state
|
||||
|
||||
// GetPortForwarded is used by the control HTTP server
|
||||
// to obtain the port currently forwarded.
|
||||
func (s *State) GetPortForwarded() (port uint16) {
|
||||
s.portForwardedMu.RLock()
|
||||
defer s.portForwardedMu.RUnlock()
|
||||
return s.portForwarded
|
||||
}
|
||||
|
||||
// SetPortForwarded is only used from within the OpenVPN loop
|
||||
// to set the port forwarded.
|
||||
func (s *State) SetPortForwarded(port uint16) {
|
||||
s.portForwardedMu.Lock()
|
||||
defer s.portForwardedMu.Unlock()
|
||||
s.portForwarded = port
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func (s *State) GetSettings() (settings settings.PortForwarding) {
|
||||
s.settingsMu.RLock()
|
||||
defer s.settingsMu.RUnlock()
|
||||
return s.settings
|
||||
}
|
||||
|
||||
func (s *State) SetSettings(ctx context.Context, settings settings.PortForwarding) (
|
||||
outcome string) {
|
||||
s.settingsMu.Lock()
|
||||
|
||||
settingsUnchanged := reflect.DeepEqual(s.settings, settings)
|
||||
if settingsUnchanged {
|
||||
s.settingsMu.Unlock()
|
||||
return "settings left unchanged"
|
||||
}
|
||||
|
||||
if s.settings.Filepath != settings.Filepath {
|
||||
_ = os.Rename(*s.settings.Filepath, *settings.Filepath)
|
||||
}
|
||||
|
||||
newEnabled := *settings.Enabled
|
||||
previousEnabled := *s.settings.Enabled
|
||||
|
||||
s.settings = settings
|
||||
s.settingsMu.Unlock()
|
||||
|
||||
switch {
|
||||
case !newEnabled && !previousEnabled:
|
||||
case newEnabled && previousEnabled:
|
||||
// no need to restart for now since we os.Rename the file here.
|
||||
case newEnabled && !previousEnabled:
|
||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
|
||||
case !newEnabled && previousEnabled:
|
||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
|
||||
}
|
||||
|
||||
return "settings updated"
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
type StartData struct {
|
||||
PortForwarder provider.PortForwarder
|
||||
Gateway netip.Addr // needed for PIA
|
||||
ServerName string // needed for PIA
|
||||
Interface string // tun0 for example
|
||||
}
|
||||
|
||||
func (s *State) GetStartData() (startData StartData) {
|
||||
s.startDataMu.RLock()
|
||||
defer s.startDataMu.RUnlock()
|
||||
return s.startData
|
||||
}
|
||||
|
||||
func (s *State) SetStartData(startData StartData) {
|
||||
s.startDataMu.Lock()
|
||||
defer s.startDataMu.Unlock()
|
||||
s.startData = startData
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func New(statusApplier StatusApplier,
|
||||
settings settings.PortForwarding) *State {
|
||||
return &State{
|
||||
statusApplier: statusApplier,
|
||||
settings: settings,
|
||||
}
|
||||
}
|
||||
|
||||
type State struct {
|
||||
statusApplier StatusApplier
|
||||
|
||||
settings settings.PortForwarding
|
||||
settingsMu sync.RWMutex
|
||||
|
||||
portForwarded uint16
|
||||
portForwardedMu sync.RWMutex
|
||||
|
||||
startData StartData
|
||||
startDataMu sync.RWMutex
|
||||
}
|
||||
|
||||
type StatusApplier interface {
|
||||
ApplyStatus(ctx context.Context, status models.LoopStatus) (
|
||||
outcome string, err error)
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
)
|
||||
|
||||
func (l *Loop) GetStatus() (status models.LoopStatus) {
|
||||
return l.statusManager.GetStatus()
|
||||
}
|
||||
|
||||
type StartData = state.StartData
|
||||
|
||||
func (l *Loop) Start(ctx context.Context, data StartData) (
|
||||
outcome string, err error) {
|
||||
l.startMu.Lock()
|
||||
defer l.startMu.Unlock()
|
||||
l.state.SetStartData(data)
|
||||
return l.statusManager.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
func (l *Loop) Stop(ctx context.Context) (outcome string, err error) {
|
||||
return l.statusManager.ApplyStatus(ctx, constants.Stopped)
|
||||
}
|
||||
@@ -21,6 +21,7 @@ type Provider interface {
|
||||
}
|
||||
|
||||
type PortForwarder interface {
|
||||
Name() string
|
||||
PortForward(ctx context.Context, client *http.Client,
|
||||
logger utils.Logger, gateway netip.Addr, serverName string) (
|
||||
port uint16, err error)
|
||||
|
||||
@@ -2,14 +2,14 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"errors"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
|
||||
func (l *Loop) cleanup(vpnProvider string) {
|
||||
for _, vpnPort := range l.vpnInputPorts {
|
||||
err := l.fw.RemoveAllowedPort(ctx, vpnPort)
|
||||
err := l.fw.RemoveAllowedPort(context.Background(), vpnPort)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot remove allowed input port from firewall: " + err.Error())
|
||||
}
|
||||
@@ -17,11 +17,11 @@ func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
|
||||
|
||||
l.publicip.SetData(models.PublicIP{}) // clear public IP address data
|
||||
|
||||
if pfEnabled {
|
||||
const pfTimeout = 100 * time.Millisecond
|
||||
err := l.stopPortForwarding(ctx, pfTimeout)
|
||||
err := l.stopPortForwarding(vpnProvider)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot stop port forwarding: " + err.Error())
|
||||
portForwardingAlreadyStopped := errors.Is(err, context.Canceled)
|
||||
if !portForwardingAlreadyStopped {
|
||||
l.logger.Error("stopping port forwarding: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func ptrTo[T any](value T) *T { return &value }
|
||||
|
||||
// waitForError waits 100ms for an error in the waitError channel.
|
||||
func (l *Loop) waitForError(ctx context.Context,
|
||||
waitError chan error) (err error) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
portforward "github.com/qdm12/gluetun/internal/portforward/service"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
@@ -22,8 +22,7 @@ type Routing interface {
|
||||
}
|
||||
|
||||
type PortForward interface {
|
||||
Start(ctx context.Context, data portforward.StartData) (outcome string, err error)
|
||||
Stop(ctx context.Context) (outcome string, err error)
|
||||
UpdateWith(settings portforward.Settings) (err error)
|
||||
}
|
||||
|
||||
type OpenVPN interface {
|
||||
|
||||
@@ -1,47 +1,35 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/portforward/service"
|
||||
)
|
||||
|
||||
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
|
||||
if !data.portForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only used for PIA for now
|
||||
func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
|
||||
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err)
|
||||
}
|
||||
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
||||
|
||||
pfData := portforward.StartData{
|
||||
partialUpdate := service.Settings{
|
||||
PortForwarder: data.portForwarder,
|
||||
Gateway: gateway,
|
||||
ServerName: data.serverName,
|
||||
Interface: data.vpnIntf,
|
||||
ServerName: data.serverName,
|
||||
VPNProvider: data.portForwarder.Name(),
|
||||
}
|
||||
_, err = l.portForward.Start(ctx, pfData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting port forwarding: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return l.portForward.UpdateWith(partialUpdate)
|
||||
}
|
||||
|
||||
func (l *Loop) stopPortForwarding(ctx context.Context,
|
||||
timeout time.Duration) (err error) {
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
func (l *Loop) stopPortForwarding(vpnProvider string) (err error) {
|
||||
partialUpdate := service.Settings{
|
||||
VPNProvider: vpnProvider,
|
||||
UserSettings: settings.PortForwarding{
|
||||
Enabled: ptrTo(false),
|
||||
},
|
||||
}
|
||||
|
||||
_, err = l.portForward.Stop(ctx)
|
||||
return err
|
||||
return l.portForward.UpdateWith(partialUpdate)
|
||||
}
|
||||
|
||||
@@ -22,10 +22,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
|
||||
providerConf := l.providers.Get(*settings.Provider.Name)
|
||||
|
||||
portForwarding := *settings.Provider.PortForwarding.Enabled
|
||||
customPortForwardingProvider := *settings.Provider.PortForwarding.Provider
|
||||
portForwader := providerConf
|
||||
if portForwarding && customPortForwardingProvider != "" {
|
||||
if customPortForwardingProvider != "" {
|
||||
portForwader = l.providers.Get(customPortForwardingProvider)
|
||||
}
|
||||
|
||||
@@ -49,7 +48,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
portForwarding: portForwarding,
|
||||
serverName: serverName,
|
||||
portForwarder: portForwader,
|
||||
vpnIntf: vpnInterface,
|
||||
@@ -76,7 +74,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
case <-tunnelReady:
|
||||
go l.onTunnelUp(openvpnCtx, tunnelUpData)
|
||||
case <-ctx.Done():
|
||||
l.cleanup(context.Background(), portForwarding)
|
||||
l.cleanup(portForwader.Name())
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
@@ -84,7 +82,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
l.cleanup(context.Background(), portForwarding)
|
||||
l.cleanup(portForwader.Name())
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
// do not close waitError or the waitError
|
||||
@@ -97,7 +95,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
case err := <-waitError: // unexpected error
|
||||
l.statusManager.Lock() // prevent SetStatus from running in parallel
|
||||
|
||||
l.cleanup(context.Background(), portForwarding)
|
||||
l.cleanup(portForwader.Name())
|
||||
openvpnCancel()
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type tunnelUpData struct {
|
||||
// Port forwarding
|
||||
portForwarding bool
|
||||
vpnIntf string
|
||||
serverName string
|
||||
portForwarder provider.PortForwarder
|
||||
@@ -42,7 +41,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
err := l.startPortForwarding(ctx, data)
|
||||
err := l.startPortForwarding(data)
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user