fix(portforward): rework run loop and fix deadlocks (#1874)

This commit is contained in:
Quentin McGaw
2023-09-23 12:57:12 +02:00
committed by GitHub
parent c435bbb32c
commit 71201411f4
30 changed files with 453 additions and 476 deletions

View File

@@ -2,14 +2,14 @@ package vpn
import (
"context"
"time"
"errors"
"github.com/qdm12/gluetun/internal/models"
)
func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
func (l *Loop) cleanup(vpnProvider string) {
for _, vpnPort := range l.vpnInputPorts {
err := l.fw.RemoveAllowedPort(ctx, vpnPort)
err := l.fw.RemoveAllowedPort(context.Background(), vpnPort)
if err != nil {
l.logger.Error("cannot remove allowed input port from firewall: " + err.Error())
}
@@ -17,11 +17,11 @@ func (l *Loop) cleanup(ctx context.Context, pfEnabled bool) {
l.publicip.SetData(models.PublicIP{}) // clear public IP address data
if pfEnabled {
const pfTimeout = 100 * time.Millisecond
err := l.stopPortForwarding(ctx, pfTimeout)
if err != nil {
l.logger.Error("cannot stop port forwarding: " + err.Error())
err := l.stopPortForwarding(vpnProvider)
if err != nil {
portForwardingAlreadyStopped := errors.Is(err, context.Canceled)
if !portForwardingAlreadyStopped {
l.logger.Error("stopping port forwarding: " + err.Error())
}
}
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
func ptrTo[T any](value T) *T { return &value }
// waitForError waits 100ms for an error in the waitError channel.
func (l *Loop) waitForError(ctx context.Context,
waitError chan error) (err error) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/portforward"
portforward "github.com/qdm12/gluetun/internal/portforward/service"
"github.com/qdm12/gluetun/internal/provider"
)
@@ -22,8 +22,7 @@ type Routing interface {
}
type PortForward interface {
Start(ctx context.Context, data portforward.StartData) (outcome string, err error)
Stop(ctx context.Context) (outcome string, err error)
UpdateWith(settings portforward.Settings) (err error)
}
type OpenVPN interface {

View File

@@ -1,47 +1,35 @@
package vpn
import (
"context"
"fmt"
"time"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/portforward/service"
)
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
if !data.portForwarding {
return nil
}
// only used for PIA for now
func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
if err != nil {
return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err)
}
l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{
partialUpdate := service.Settings{
PortForwarder: data.portForwarder,
Gateway: gateway,
ServerName: data.serverName,
Interface: data.vpnIntf,
ServerName: data.serverName,
VPNProvider: data.portForwarder.Name(),
}
_, err = l.portForward.Start(ctx, pfData)
if err != nil {
return fmt.Errorf("starting port forwarding: %w", err)
}
return nil
return l.portForward.UpdateWith(partialUpdate)
}
func (l *Loop) stopPortForwarding(ctx context.Context,
timeout time.Duration) (err error) {
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
func (l *Loop) stopPortForwarding(vpnProvider string) (err error) {
partialUpdate := service.Settings{
VPNProvider: vpnProvider,
UserSettings: settings.PortForwarding{
Enabled: ptrTo(false),
},
}
_, err = l.portForward.Stop(ctx)
return err
return l.portForward.UpdateWith(partialUpdate)
}

View File

@@ -22,10 +22,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
providerConf := l.providers.Get(*settings.Provider.Name)
portForwarding := *settings.Provider.PortForwarding.Enabled
customPortForwardingProvider := *settings.Provider.PortForwarding.Provider
portForwader := providerConf
if portForwarding && customPortForwardingProvider != "" {
if customPortForwardingProvider != "" {
portForwader = l.providers.Get(customPortForwardingProvider)
}
@@ -49,10 +48,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue
}
tunnelUpData := tunnelUpData{
portForwarding: portForwarding,
serverName: serverName,
portForwarder: portForwader,
vpnIntf: vpnInterface,
serverName: serverName,
portForwarder: portForwader,
vpnIntf: vpnInterface,
}
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
@@ -76,7 +74,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-tunnelReady:
go l.onTunnelUp(openvpnCtx, tunnelUpData)
case <-ctx.Done():
l.cleanup(context.Background(), portForwarding)
l.cleanup(portForwader.Name())
openvpnCancel()
<-waitError
close(waitError)
@@ -84,7 +82,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
l.cleanup(context.Background(), portForwarding)
l.cleanup(portForwader.Name())
openvpnCancel()
<-waitError
// do not close waitError or the waitError
@@ -97,7 +95,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case err := <-waitError: // unexpected error
l.statusManager.Lock() // prevent SetStatus from running in parallel
l.cleanup(context.Background(), portForwarding)
l.cleanup(portForwader.Name())
openvpnCancel()
l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err)

View File

@@ -10,10 +10,9 @@ import (
type tunnelUpData struct {
// Port forwarding
portForwarding bool
vpnIntf string
serverName string
portForwarder provider.PortForwarder
vpnIntf string
serverName string
portForwarder provider.PortForwarder
}
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
@@ -42,7 +41,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
}
}
err := l.startPortForwarding(ctx, data)
err := l.startPortForwarding(data)
if err != nil {
l.logger.Error(err.Error())
}