From 0406de399d5055eeb421acebe4e4e36382ece2c5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 23 Sep 2023 11:46:14 +0000 Subject: [PATCH] chore(portforward): move vpn gateway obtention within port forwarding service --- cmd/gluetun/main.go | 2 +- internal/portforward/interfaces.go | 9 +++- internal/portforward/loop.go | 6 ++- internal/portforward/service/interfaces.go | 14 ++++++ internal/portforward/service/service.go | 4 +- internal/portforward/service/settings.go | 19 +++----- internal/portforward/service/start.go | 26 ++++++++--- .../privateinternetaccess/portforward.go | 45 ++++++++++++------- internal/provider/protonvpn/portforward.go | 41 +++++++---------- internal/provider/protonvpn/provider.go | 1 + internal/provider/provider.go | 8 +--- internal/provider/utils/noportforward.go | 16 +++---- internal/provider/utils/portforward.go | 27 +++++++++++ internal/vpn/portforward.go | 9 ---- 14 files changed, 135 insertions(+), 92 deletions(-) create mode 100644 internal/provider/utils/portforward.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 0131c7ee..e91c431e 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -376,7 +376,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, portForwardLogger := logger.New(log.SetComponent("port forwarding")) portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding, - httpClient, firewallConf, portForwardLogger, puid, pgid) + routingConf, httpClient, firewallConf, portForwardLogger, puid, pgid) portForwardRunError, _ := portForwardLooper.Start(context.Background()) unboundLogger := logger.New(log.SetComponent("dns")) diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index 7811c733..15face10 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -1,6 +1,9 @@ package portforward -import "context" +import ( + "context" + "net/netip" +) type Service interface { Start(ctx context.Context) (runError <-chan error, err error) @@ -8,6 +11,10 @@ type Service interface { GetPortForwarded() (port uint16) } +type Routing interface { + VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) +} + type PortAllower interface { SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error) diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index f8ff4cc2..30f33e77 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -16,6 +16,7 @@ type Loop struct { settingsMutex sync.RWMutex service Service // Fixed injected objets + routing Routing client *http.Client portAllower PortAllower logger Logger @@ -30,13 +31,14 @@ type Loop struct { runDone <-chan struct{} } -func NewLoop(settings settings.PortForwarding, +func NewLoop(settings settings.PortForwarding, routing Routing, client *http.Client, portAllower PortAllower, logger Logger, uid, gid int) *Loop { return &Loop{ settings: service.Settings{ UserSettings: settings, }, + routing: routing, client: client, portAllower: portAllower, logger: logger, @@ -85,7 +87,7 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, } l.settingsMutex.RLock() - l.service = service.New(l.settings, l.client, + l.service = service.New(l.settings, l.routing, l.client, l.portAllower, l.logger, l.uid, l.gid) l.settingsMutex.RUnlock() diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go index 0265ea7f..eed25869 100644 --- a/internal/portforward/service/interfaces.go +++ b/internal/portforward/service/interfaces.go @@ -2,6 +2,9 @@ package service import ( "context" + "net/netip" + + "github.com/qdm12/gluetun/internal/provider/utils" ) type PortAllower interface { @@ -9,8 +12,19 @@ type PortAllower interface { RemoveAllowedPort(ctx context.Context, port uint16) (err error) } +type Routing interface { + VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) +} + type Logger interface { Info(s string) Warn(s string) Error(s string) } + +type PortForwarder interface { + Name() string + PortForward(ctx context.Context, objects utils.PortForwardObjects) ( + port uint16, err error) + KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) +} diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go index 16cd07db..ba954289 100644 --- a/internal/portforward/service/service.go +++ b/internal/portforward/service/service.go @@ -15,6 +15,7 @@ type Service struct { puid int pgid int // Fixed injected objets + routing Routing client *http.Client portAllower PortAllower logger Logger @@ -24,7 +25,7 @@ type Service struct { keepPortDoneCh <-chan struct{} } -func New(settings Settings, client *http.Client, +func New(settings Settings, routing Routing, client *http.Client, portAllower PortAllower, logger Logger, puid, pgid int) *Service { return &Service{ // Fixed parameters @@ -32,6 +33,7 @@ func New(settings Settings, client *http.Client, puid: puid, pgid: pgid, // Fixed injected objets + routing: routing, client: client, portAllower: portAllower, logger: logger, diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index 8f8ce108..43ebb118 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -3,21 +3,18 @@ package service import ( "errors" "fmt" - "net/netip" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants/providers" - "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gosettings" ) type Settings struct { UserSettings settings.PortForwarding - PortForwarder provider.PortForwarder - Gateway netip.Addr // needed for PIA and ProtonVPN - ServerName string // needed for PIA - Interface string // needed for PIA and ProtonVPN, tun0 for example - VPNProvider string // used to validate new settings + PortForwarder PortForwarder + Interface string // needed for PIA and ProtonVPN, tun0 for example + ServerName string // needed for PIA + VPNProvider string // used to validate new settings } // UpdateWith deep copies the receiving settings, overrides the copy with @@ -37,9 +34,8 @@ func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, func (s Settings) copy() (copied Settings) { copied.UserSettings = s.UserSettings.Copy() copied.PortForwarder = s.PortForwarder - copied.Gateway = s.Gateway - copied.ServerName = s.ServerName copied.Interface = s.Interface + copied.ServerName = s.ServerName copied.VPNProvider = s.VPNProvider return copied } @@ -47,9 +43,8 @@ func (s Settings) copy() (copied Settings) { func (s *Settings) overrideWith(update Settings) { s.UserSettings.OverrideWith(update.UserSettings) s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder) - s.Gateway = gosettings.OverrideWithValidator(s.Gateway, update.Gateway) - s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName) s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface) + s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName) s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider) } @@ -69,8 +64,6 @@ func (s *Settings) validate() (err error) { return fmt.Errorf("%w", ErrServerNameNotSet) case s.PortForwarder == nil: return fmt.Errorf("%w", ErrPortForwarderNotSet) - case !s.Gateway.IsValid(): - return fmt.Errorf("%w", ErrGatewayNotSet) case s.Interface == "": return fmt.Errorf("%w", ErrInterfaceNotSet) } diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index c50f5c53..3afa0311 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -3,6 +3,8 @@ package service import ( "context" "fmt" + + "github.com/qdm12/gluetun/internal/provider/utils" ) func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) { @@ -14,8 +16,19 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) } s.logger.Info("starting") - port, err := s.settings.PortForwarder.PortForward(ctx, s.client, s.logger, - s.settings.Gateway, s.settings.ServerName) + + gateway, err := s.routing.VPNLocalGatewayIP(s.settings.Interface) + if err != nil { + return nil, fmt.Errorf("getting VPN local gateway IP: %w", err) + } + + obj := utils.PortForwardObjects{ + Logger: s.logger, + Gateway: gateway, + Client: s.client, + ServerName: s.settings.ServerName, + } + port, err := s.settings.PortForwarder.PortForward(ctx, obj) if err != nil { return nil, fmt.Errorf("port forwarding for the first time: %w", err) } @@ -43,18 +56,17 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) keepPortDoneCh := make(chan struct{}) s.keepPortDoneCh = keepPortDoneCh - go func(ctx context.Context, settings Settings, port uint16, - runError chan<- error, doneCh chan<- struct{}) { + go func(ctx context.Context, portForwarder PortForwarder, + obj utils.PortForwardObjects, runError chan<- error, doneCh chan<- struct{}) { defer close(doneCh) - err = settings.PortForwarder.KeepPortForward(ctx, port, - settings.Gateway, settings.ServerName, s.logger) + err = portForwarder.KeepPortForward(ctx, obj) crashed := ctx.Err() == nil if !crashed { // stopped by Stop call return } _ = s.cleanup() runError <- err - }(keepPortCtx, s.settings, port, runErrorCh, keepPortDoneCh) + }(keepPortCtx, s.settings.PortForwarder, obj, runErrorCh, keepPortDoneCh) return runErrorCh, nil } diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 93f49b3b..563c0453 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -22,30 +22,33 @@ import ( ) var ( - ErrServerNameNotFound = errors.New("server name not found in servers") - ErrGatewayIPIsNotValid = errors.New("gateway IP address is not valid") - ErrServerNameEmpty = errors.New("server name is empty") + ErrServerNameNotFound = errors.New("server name not found in servers") ) // PortForward obtains a VPN server side port forwarded from PIA. -func (p *Provider) PortForward(ctx context.Context, client *http.Client, - logger utils.Logger, gateway netip.Addr, serverName string) ( - port uint16, err error) { +func (p *Provider) PortForward(ctx context.Context, + objects utils.PortForwardObjects) (port uint16, err error) { + switch { + case objects.ServerName == "": + panic("server name cannot be empty") + case !objects.Gateway.IsValid(): + panic("gateway is not set") + } + + serverName := objects.ServerName + server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName) if !ok { return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName) } + logger := objects.Logger + if !server.PortForward { logger.Error("The server " + serverName + " (region " + server.Region + ") does not support port forwarding") return 0, nil } - if !gateway.IsValid() { - return 0, fmt.Errorf("%w: %s", ErrGatewayIPIsNotValid, gateway) - } else if serverName == "" { - return 0, ErrServerNameEmpty - } privateIPClient, err := newHTTPClient(serverName) if err != nil { @@ -70,7 +73,8 @@ func (p *Provider) PortForward(ctx context.Context, client *http.Client, } if !dataFound || expired { - data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, + client := objects.Client + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway, p.portForwardPath, p.authFilePath) if err != nil { return 0, fmt.Errorf("refreshing port forward data: %w", err) @@ -80,7 +84,7 @@ func (p *Provider) PortForward(ctx context.Context, client *http.Client, logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration)) // First time binding - if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { + if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil { return 0, fmt.Errorf("binding port: %w", err) } @@ -91,9 +95,16 @@ var ( ErrPortForwardedExpired = errors.New("port forwarded data expired") ) -func (p *Provider) KeepPortForward(ctx context.Context, _ uint16, - gateway netip.Addr, serverName string, _ utils.Logger) (err error) { - privateIPClient, err := newHTTPClient(serverName) +func (p *Provider) KeepPortForward(ctx context.Context, + objects utils.PortForwardObjects) (err error) { + switch { + case objects.ServerName == "": + panic("server name cannot be empty") + case !objects.Gateway.IsValid(): + panic("gateway is not set") + } + + privateIPClient, err := newHTTPClient(objects.ServerName) if err != nil { return fmt.Errorf("creating custom HTTP client: %w", err) } @@ -120,7 +131,7 @@ func (p *Provider) KeepPortForward(ctx context.Context, _ uint16, } return ctx.Err() case <-keepAliveTimer.C: - err := bindPort(ctx, privateIPClient, gateway, data) + err = bindPort(ctx, privateIPClient, objects.Gateway, data) if err != nil { return fmt.Errorf("binding port: %w", err) } diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go index c83fc24a..0d13f2c8 100644 --- a/internal/provider/protonvpn/portforward.go +++ b/internal/provider/protonvpn/portforward.go @@ -2,10 +2,7 @@ package protonvpn import ( "context" - "errors" "fmt" - "net/http" - "net/netip" "strings" "time" @@ -13,31 +10,24 @@ import ( "github.com/qdm12/gluetun/internal/provider/utils" ) -var ( - ErrGatewayIPNotValid = errors.New("gateway IP address is not valid") -) - // PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. -func (p *Provider) PortForward(ctx context.Context, _ *http.Client, - logger utils.Logger, gateway netip.Addr, _ string) ( +func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( port uint16, err error) { - if !gateway.IsValid() { - return 0, fmt.Errorf("%w", ErrGatewayIPNotValid) - } - client := natpmp.New() _, externalIPv4Address, err := client.ExternalAddress(ctx, - gateway) + objects.Gateway) if err != nil { return 0, fmt.Errorf("getting external IPv4 address: %w", err) } + logger := objects.Logger + logger.Info("gateway external IPv4 address is " + externalIPv4Address.String()) const internalPort, externalPort = 0, 0 const lifetime = 60 * time.Second _, _, assignedUDPExternalPort, assignedLifetime, err := - client.AddPortMapping(ctx, gateway, "udp", + client.AddPortMapping(ctx, objects.Gateway, "udp", internalPort, externalPort, lifetime) if err != nil { return 0, fmt.Errorf("adding UDP port mapping: %w", err) @@ -45,7 +35,7 @@ func (p *Provider) PortForward(ctx context.Context, _ *http.Client, checkLifetime(logger, "UDP", lifetime, assignedLifetime) _, _, assignedTCPExternalPort, assignedLifetime, err := - client.AddPortMapping(ctx, gateway, "tcp", + client.AddPortMapping(ctx, objects.Gateway, "tcp", internalPort, externalPort, lifetime) if err != nil { return 0, fmt.Errorf("adding TCP port mapping: %w", err) @@ -55,6 +45,8 @@ func (p *Provider) PortForward(ctx context.Context, _ *http.Client, checkExternalPorts(logger, assignedUDPExternalPort, assignedTCPExternalPort) port = assignedTCPExternalPort + p.portForwarded = port + return port, nil } @@ -74,11 +66,12 @@ func checkExternalPorts(logger utils.Logger, udpPort, tcpPort uint16) { } } -func (p *Provider) KeepPortForward(ctx context.Context, port uint16, - gateway netip.Addr, _ string, logger utils.Logger) (err error) { +func (p *Provider) KeepPortForward(ctx context.Context, + objects utils.PortForwardObjects) (err error) { client := natpmp.New() const refreshTimeout = 45 * time.Second timer := time.NewTimer(refreshTimeout) + logger := objects.Logger for { select { case <-ctx.Done(): @@ -92,8 +85,8 @@ func (p *Provider) KeepPortForward(ctx context.Context, port uint16, for _, networkProtocol := range networkProtocols { _, _, assignedExternalPort, assignedLiftetime, err := - client.AddPortMapping(ctx, gateway, networkProtocol, - internalPort, port, lifetime) + client.AddPortMapping(ctx, objects.Gateway, networkProtocol, + internalPort, p.portForwarded, lifetime) if err != nil { return fmt.Errorf("adding port mapping: %w", err) } @@ -104,10 +97,10 @@ func (p *Provider) KeepPortForward(ctx context.Context, port uint16, assignedLiftetime, lifetime)) } - if port != assignedExternalPort { - logger.Warn(fmt.Sprintf("external port assigned %d changed to %d", - port, assignedExternalPort)) - port = assignedExternalPort + if p.portForwarded != assignedExternalPort { + objects.Logger.Warn(fmt.Sprintf("external port assigned %d changed to %d", + p.portForwarded, assignedExternalPort)) + p.portForwarded = assignedExternalPort } } diff --git a/internal/provider/protonvpn/provider.go b/internal/provider/protonvpn/provider.go index b0e5d711..902d1d16 100644 --- a/internal/provider/protonvpn/provider.go +++ b/internal/provider/protonvpn/provider.go @@ -13,6 +13,7 @@ type Provider struct { storage common.Storage randSource rand.Source common.Fetcher + portForwarded uint16 } func New(storage common.Storage, randSource rand.Source, diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 5b3e4efb..96b48df7 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -2,8 +2,6 @@ package provider import ( "context" - "net/http" - "net/netip" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" @@ -22,9 +20,7 @@ type Provider interface { type PortForwarder interface { Name() string - PortForward(ctx context.Context, client *http.Client, - logger utils.Logger, gateway netip.Addr, serverName string) ( + PortForward(ctx context.Context, objects utils.PortForwardObjects) ( port uint16, err error) - KeepPortForward(ctx context.Context, port uint16, gateway netip.Addr, - serverName string, _ utils.Logger) (err error) + KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) } diff --git a/internal/provider/utils/noportforward.go b/internal/provider/utils/noportforward.go index 2df92e5f..4d49df81 100644 --- a/internal/provider/utils/noportforward.go +++ b/internal/provider/utils/noportforward.go @@ -4,16 +4,11 @@ import ( "context" "errors" "fmt" - "net/http" - "net/netip" ) type NoPortForwarder interface { - PortForward(ctx context.Context, client *http.Client, - logger Logger, gateway netip.Addr, serverName string) ( - port uint16, err error) - KeepPortForward(ctx context.Context, port uint16, gateway netip.Addr, - serverName string, logger Logger) (err error) + PortForward(ctx context.Context, objects PortForwardObjects) (port uint16, err error) + KeepPortForward(ctx context.Context, objects PortForwardObjects) (err error) } type NoPortForwarding struct { @@ -28,12 +23,11 @@ func NewNoPortForwarding(providerName string) *NoPortForwarding { var ErrPortForwardingNotSupported = errors.New("custom port forwarding obtention is not supported") -func (n *NoPortForwarding) PortForward(context.Context, *http.Client, - Logger, netip.Addr, string) (port uint16, err error) { +func (n *NoPortForwarding) PortForward(context.Context, PortForwardObjects) ( + port uint16, err error) { return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) } -func (n *NoPortForwarding) KeepPortForward(context.Context, uint16, netip.Addr, - string, Logger) (err error) { +func (n *NoPortForwarding) KeepPortForward(context.Context, PortForwardObjects) (err error) { return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) } diff --git a/internal/provider/utils/portforward.go b/internal/provider/utils/portforward.go new file mode 100644 index 00000000..baae073d --- /dev/null +++ b/internal/provider/utils/portforward.go @@ -0,0 +1,27 @@ +package utils + +import ( + "net/http" + "net/netip" +) + +// PortForwardObjects contains fields that may or may not need to be set +// depending on the port forwarding provider code. +type PortForwardObjects struct { + // Logger is a logger, used by both Private Internet Access and ProtonVPN. + Logger Logger + // Gateway is the VPN gateway IP address, used by Private Internet Access + // and ProtonVPN. + Gateway netip.Addr + // Client is used to query the VPN gateway for Private Internet Access. + Client *http.Client + // ServerName is used by Private Internet Access for port forwarding, + // and to look up the server data from storage. + // TODO use server data directly to remove storage dependency for port + // forwarding implementation. + ServerName string +} + +type Routing interface { + VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) +} diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index d05d1f8e..9ac3eba6 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -1,22 +1,13 @@ package vpn import ( - "fmt" - "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/portforward/service" ) func (l *Loop) startPortForwarding(data tunnelUpData) (err error) { - gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf) - if err != nil { - return fmt.Errorf("obtaining VPN local gateway IP for interface %s: %w", data.vpnIntf, err) - } - l.logger.Info("VPN gateway IP address: " + gateway.String()) - partialUpdate := service.Settings{ PortForwarder: data.portForwarder, - Gateway: gateway, Interface: data.vpnIntf, ServerName: data.serverName, VPNProvider: data.portForwarder.Name(),