fix(portforward): rework run loop and fix deadlocks (#1874)
This commit is contained in:
@@ -2,14 +2,14 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"errors"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
|
||||
func (l *Loop) cleanup(vpnProvider string) {
|
||||
for _, vpnPort := range l.vpnInputPorts {
|
||||
err := l.fw.RemoveAllowedPort(ctx, vpnPort)
|
||||
err := l.fw.RemoveAllowedPort(context.Background(), vpnPort)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot remove allowed input port from firewall: " + err.Error())
|
||||
}
|
||||
@@ -17,11 +17,11 @@ func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
|
||||
|
||||
l.publicip.SetData(models.PublicIP{}) // clear public IP address data
|
||||
|
||||
if pfEnabled {
|
||||
const pfTimeout = 100 * time.Millisecond
|
||||
err := l.stopPortForwarding(ctx, pfTimeout)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot stop port forwarding: " + err.Error())
|
||||
err := l.stopPortForwarding(vpnProvider)
|
||||
if err != nil {
|
||||
portForwardingAlreadyStopped := errors.Is(err, context.Canceled)
|
||||
if !portForwardingAlreadyStopped {
|
||||
l.logger.Error("stopping port forwarding: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func ptrTo[T any](value T) *T { return &value }
|
||||
|
||||
// waitForError waits 100ms for an error in the waitError channel.
|
||||
func (l *Loop) waitForError(ctx context.Context,
|
||||
waitError chan error) (err error) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
portforward "github.com/qdm12/gluetun/internal/portforward/service"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
@@ -22,8 +22,7 @@ type Routing interface {
|
||||
}
|
||||
|
||||
type PortForward interface {
|
||||
Start(ctx context.Context, data portforward.StartData) (outcome string, err error)
|
||||
Stop(ctx context.Context) (outcome string, err error)
|
||||
UpdateWith(settings portforward.Settings) (err error)
|
||||
}
|
||||
|
||||
type OpenVPN interface {
|
||||
|
||||
@@ -1,47 +1,35 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/portforward/service"
|
||||
)
|
||||
|
||||
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
|
||||
if !data.portForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only used for PIA for now
|
||||
func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
|
||||
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err)
|
||||
}
|
||||
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
||||
|
||||
pfData := portforward.StartData{
|
||||
partialUpdate := service.Settings{
|
||||
PortForwarder: data.portForwarder,
|
||||
Gateway: gateway,
|
||||
ServerName: data.serverName,
|
||||
Interface: data.vpnIntf,
|
||||
ServerName: data.serverName,
|
||||
VPNProvider: data.portForwarder.Name(),
|
||||
}
|
||||
_, err = l.portForward.Start(ctx, pfData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting port forwarding: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return l.portForward.UpdateWith(partialUpdate)
|
||||
}
|
||||
|
||||
func (l *Loop) stopPortForwarding(ctx context.Context,
|
||||
timeout time.Duration) (err error) {
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
func (l *Loop) stopPortForwarding(vpnProvider string) (err error) {
|
||||
partialUpdate := service.Settings{
|
||||
VPNProvider: vpnProvider,
|
||||
UserSettings: settings.PortForwarding{
|
||||
Enabled: ptrTo(false),
|
||||
},
|
||||
}
|
||||
|
||||
_, err = l.portForward.Stop(ctx)
|
||||
return err
|
||||
return l.portForward.UpdateWith(partialUpdate)
|
||||
}
|
||||
|
||||
@@ -22,10 +22,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
|
||||
providerConf := l.providers.Get(*settings.Provider.Name)
|
||||
|
||||
portForwarding := *settings.Provider.PortForwarding.Enabled
|
||||
customPortForwardingProvider := *settings.Provider.PortForwarding.Provider
|
||||
portForwader := providerConf
|
||||
if portForwarding && customPortForwardingProvider != "" {
|
||||
if customPortForwardingProvider != "" {
|
||||
portForwader = l.providers.Get(customPortForwardingProvider)
|
||||
}
|
||||
|
||||
@@ -49,10 +48,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
portForwarding: portForwarding,
|
||||
serverName: serverName,
|
||||
portForwarder: portForwader,
|
||||
vpnIntf: vpnInterface,
|
||||
serverName: serverName,
|
||||
portForwarder: portForwader,
|
||||
vpnIntf: vpnInterface,
|
||||
}
|
||||
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
|
||||
@@ -76,7 +74,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
case <-tunnelReady:
|
||||
go l.onTunnelUp(openvpnCtx, tunnelUpData)
|
||||
case <-ctx.Done():
|
||||
l.cleanup(context.Background(), portForwarding)
|
||||
l.cleanup(portForwader.Name())
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
@@ -84,7 +82,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
l.cleanup(context.Background(), portForwarding)
|
||||
l.cleanup(portForwader.Name())
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
// do not close waitError or the waitError
|
||||
@@ -97,7 +95,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
case err := <-waitError: // unexpected error
|
||||
l.statusManager.Lock() // prevent SetStatus from running in parallel
|
||||
|
||||
l.cleanup(context.Background(), portForwarding)
|
||||
l.cleanup(portForwader.Name())
|
||||
openvpnCancel()
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
|
||||
@@ -10,10 +10,9 @@ import (
|
||||
|
||||
type tunnelUpData struct {
|
||||
// Port forwarding
|
||||
portForwarding bool
|
||||
vpnIntf string
|
||||
serverName string
|
||||
portForwarder provider.PortForwarder
|
||||
vpnIntf string
|
||||
serverName string
|
||||
portForwarder provider.PortForwarder
|
||||
}
|
||||
|
||||
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
@@ -42,7 +41,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
err := l.startPortForwarding(ctx, data)
|
||||
err := l.startPortForwarding(data)
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user