feat(portforwarding): VPN_PORT_FORWARDING_DOWN_COMMAND option

This commit is contained in:
Quentin McGaw
2024-11-10 10:18:29 +00:00
parent a035a151bd
commit 0374c14e42
8 changed files with 36 additions and 7 deletions

View File

@@ -44,6 +44,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
UpCommand: *settings.UpCommand,
DownCommand: *settings.DownCommand,
ListeningPort: *settings.ListeningPort,
},
},

View File

@@ -9,7 +9,7 @@ import (
"github.com/qdm12/gluetun/internal/command"
)
func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
commandTemplate string, ports []uint16,
) (err error) {
portStrings := make([]string, len(ports))

View File

@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Service_runUpCommand(t *testing.T) {
func Test_Service_runCommand(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
@@ -22,7 +22,7 @@ func Test_Service_runUpCommand(t *testing.T) {
logger := NewMockLogger(ctrl)
logger.EXPECT().Info("1234,5678")
err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
err := runCommand(ctx, cmder, logger, commandTemplate, ports)
require.NoError(t, err)
}

View File

@@ -13,6 +13,7 @@ type Settings struct {
PortForwarder PortForwarder
Filepath string
UpCommand string
DownCommand string
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
ServerName string // needed for PIA
CanPortForward bool // needed for PIA
@@ -26,6 +27,7 @@ func (s Settings) Copy() (copied Settings) {
copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.UpCommand = s.UpCommand
copied.DownCommand = s.DownCommand
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.CanPortForward = s.CanPortForward
@@ -40,6 +42,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
s.DownCommand = gosettings.OverrideWithComparable(s.DownCommand, update.DownCommand)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)

View File

@@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.portMutex.Unlock()
if s.settings.UpCommand != "" {
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
if err != nil {
err = fmt.Errorf("running up command: %w", err)
s.logger.Error(err.Error())

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"time"
)
func (s *Service) Stop() (err error) {
@@ -30,6 +31,17 @@ func (s *Service) cleanup() (err error) {
s.portMutex.Lock()
defer s.portMutex.Unlock()
if s.settings.DownCommand != "" {
const downTimeout = 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
defer cancel()
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
if err != nil {
err = fmt.Errorf("running down command: %w", err)
s.logger.Error(err.Error())
}
}
for _, port := range s.ports {
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
if err != nil {