From 0374c14e429df16734d2cfa0784027f43624c247 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sun, 10 Nov 2024 10:18:29 +0000 Subject: [PATCH] feat(portforwarding): `VPN_PORT_FORWARDING_DOWN_COMMAND` option --- Dockerfile | 1 + internal/configuration/settings/portforward.go | 18 +++++++++++++++--- internal/portforward/loop.go | 1 + internal/portforward/service/command.go | 2 +- internal/portforward/service/command_test.go | 4 ++-- internal/portforward/service/settings.go | 3 +++ internal/portforward/service/start.go | 2 +- internal/portforward/service/stop.go | 12 ++++++++++++ 8 files changed, 36 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 08e83df1..1415abcb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -126,6 +126,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ VPN_PORT_FORWARDING_USERNAME= \ VPN_PORT_FORWARDING_PASSWORD= \ VPN_PORT_FORWARDING_UP_COMMAND= \ + VPN_PORT_FORWARDING_DOWN_COMMAND= \ # # Cyberghost only: OPENVPN_CERT= \ OPENVPN_KEY= \ diff --git a/internal/configuration/settings/portforward.go b/internal/configuration/settings/portforward.go index dfdcec25..d445d657 100644 --- a/internal/configuration/settings/portforward.go +++ b/internal/configuration/settings/portforward.go @@ -33,6 +33,10 @@ type PortForwarding struct { // It can be the empty string to indicate not to run a command. // It cannot be nil in the internal state. UpCommand *string `json:"up_command"` + // DownCommand is the command to use after the port forwarding goes down. + // It can be the empty string to indicate to NOT run a command. + // It cannot be nil in the internal state. + DownCommand *string `json:"down_command"` // ListeningPort is the port traffic would be redirected to from the // forwarded port. The redirection is disabled if it is set to 0, which // is its default as well. @@ -89,6 +93,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) { Provider: gosettings.CopyPointer(p.Provider), Filepath: gosettings.CopyPointer(p.Filepath), UpCommand: gosettings.CopyPointer(p.UpCommand), + DownCommand: gosettings.CopyPointer(p.DownCommand), ListeningPort: gosettings.CopyPointer(p.ListeningPort), Username: p.Username, Password: p.Password, @@ -100,6 +105,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) { p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider) p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath) p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand) + p.DownCommand = gosettings.OverrideWithPointer(p.DownCommand, other.DownCommand) p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort) p.Username = gosettings.OverrideWithComparable(p.Username, other.Username) p.Password = gosettings.OverrideWithComparable(p.Password, other.Password) @@ -110,6 +116,7 @@ func (p *PortForwarding) setDefaults() { p.Provider = gosettings.DefaultPointer(p.Provider, "") p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port") p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "") + p.DownCommand = gosettings.DefaultPointer(p.DownCommand, "") p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0) } @@ -142,9 +149,11 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) { } node.Appendf("Forwarded port file path: %s", filepath) - command := *p.UpCommand - if command != "" { - node.Appendf("Forwarded port command: %s", command) + if *p.UpCommand != "" { + node.Appendf("Forwarded port up command: %s", *p.UpCommand) + } + if *p.DownCommand != "" { + node.Appendf("Forwarded port down command: %s", *p.DownCommand) } if p.Username != "" { @@ -178,6 +187,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) { p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND", reader.ForceLowercase(false)) + p.DownCommand = r.Get("VPN_PORT_FORWARDING_DOWN_COMMAND", + reader.ForceLowercase(false)) + p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT") if err != nil { return err diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index cc9e0a15..8f943145 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -44,6 +44,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing, Enabled: settings.Enabled, Filepath: *settings.Filepath, UpCommand: *settings.UpCommand, + DownCommand: *settings.DownCommand, ListeningPort: *settings.ListeningPort, }, }, diff --git a/internal/portforward/service/command.go b/internal/portforward/service/command.go index 30a6dcf4..2a87b3eb 100644 --- a/internal/portforward/service/command.go +++ b/internal/portforward/service/command.go @@ -9,7 +9,7 @@ import ( "github.com/qdm12/gluetun/internal/command" ) -func runUpCommand(ctx context.Context, cmder Cmder, logger Logger, +func runCommand(ctx context.Context, cmder Cmder, logger Logger, commandTemplate string, ports []uint16, ) (err error) { portStrings := make([]string, len(ports)) diff --git a/internal/portforward/service/command_test.go b/internal/portforward/service/command_test.go index 59b471aa..d2747250 100644 --- a/internal/portforward/service/command_test.go +++ b/internal/portforward/service/command_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_Service_runUpCommand(t *testing.T) { +func Test_Service_runCommand(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -22,7 +22,7 @@ func Test_Service_runUpCommand(t *testing.T) { logger := NewMockLogger(ctrl) logger.EXPECT().Info("1234,5678") - err := runUpCommand(ctx, cmder, logger, commandTemplate, ports) + err := runCommand(ctx, cmder, logger, commandTemplate, ports) require.NoError(t, err) } diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index b44d790c..2847be51 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -13,6 +13,7 @@ type Settings struct { PortForwarder PortForwarder Filepath string UpCommand string + DownCommand string Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example ServerName string // needed for PIA CanPortForward bool // needed for PIA @@ -26,6 +27,7 @@ func (s Settings) Copy() (copied Settings) { copied.PortForwarder = s.PortForwarder copied.Filepath = s.Filepath copied.UpCommand = s.UpCommand + copied.DownCommand = s.DownCommand copied.Interface = s.Interface copied.ServerName = s.ServerName copied.CanPortForward = s.CanPortForward @@ -40,6 +42,7 @@ func (s *Settings) OverrideWith(update Settings) { s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder) s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath) s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand) + s.DownCommand = gosettings.OverrideWithComparable(s.DownCommand, update.DownCommand) s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface) s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName) s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward) diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index c7fb9bbf..a13ac0e4 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) s.portMutex.Unlock() if s.settings.UpCommand != "" { - err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports) + err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports) if err != nil { err = fmt.Errorf("running up command: %w", err) s.logger.Error(err.Error()) diff --git a/internal/portforward/service/stop.go b/internal/portforward/service/stop.go index e8b8d681..208951c3 100644 --- a/internal/portforward/service/stop.go +++ b/internal/portforward/service/stop.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "time" ) func (s *Service) Stop() (err error) { @@ -30,6 +31,17 @@ func (s *Service) cleanup() (err error) { s.portMutex.Lock() defer s.portMutex.Unlock() + if s.settings.DownCommand != "" { + const downTimeout = 60 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), downTimeout) + defer cancel() + err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports) + if err != nil { + err = fmt.Errorf("running down command: %w", err) + s.logger.Error(err.Error()) + } + } + for _, port := range s.ports { err = s.portAllower.RemoveAllowedPort(context.Background(), port) if err != nil {