feat(portforwarding): VPN_PORT_FORWARDING_DOWN_COMMAND option

This commit is contained in:
Quentin McGaw
2024-11-10 10:18:29 +00:00
parent a035a151bd
commit 0374c14e42
8 changed files with 36 additions and 7 deletions

View File

@@ -126,6 +126,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
VPN_PORT_FORWARDING_USERNAME= \
VPN_PORT_FORWARDING_PASSWORD= \
VPN_PORT_FORWARDING_UP_COMMAND= \
VPN_PORT_FORWARDING_DOWN_COMMAND= \
# # Cyberghost only:
OPENVPN_CERT= \
OPENVPN_KEY= \

View File

@@ -33,6 +33,10 @@ type PortForwarding struct {
// It can be the empty string to indicate not to run a command.
// It cannot be nil in the internal state.
UpCommand *string `json:"up_command"`
// DownCommand is the command to use after the port forwarding goes down.
// It can be the empty string to indicate to NOT run a command.
// It cannot be nil in the internal state.
DownCommand *string `json:"down_command"`
// ListeningPort is the port traffic would be redirected to from the
// forwarded port. The redirection is disabled if it is set to 0, which
// is its default as well.
@@ -89,6 +93,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) {
Provider: gosettings.CopyPointer(p.Provider),
Filepath: gosettings.CopyPointer(p.Filepath),
UpCommand: gosettings.CopyPointer(p.UpCommand),
DownCommand: gosettings.CopyPointer(p.DownCommand),
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
Username: p.Username,
Password: p.Password,
@@ -100,6 +105,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) {
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand)
p.DownCommand = gosettings.OverrideWithPointer(p.DownCommand, other.DownCommand)
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
@@ -110,6 +116,7 @@ func (p *PortForwarding) setDefaults() {
p.Provider = gosettings.DefaultPointer(p.Provider, "")
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "")
p.DownCommand = gosettings.DefaultPointer(p.DownCommand, "")
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
}
@@ -142,9 +149,11 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
}
node.Appendf("Forwarded port file path: %s", filepath)
command := *p.UpCommand
if command != "" {
node.Appendf("Forwarded port command: %s", command)
if *p.UpCommand != "" {
node.Appendf("Forwarded port up command: %s", *p.UpCommand)
}
if *p.DownCommand != "" {
node.Appendf("Forwarded port down command: %s", *p.DownCommand)
}
if p.Username != "" {
@@ -178,6 +187,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) {
p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND",
reader.ForceLowercase(false))
p.DownCommand = r.Get("VPN_PORT_FORWARDING_DOWN_COMMAND",
reader.ForceLowercase(false))
p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
if err != nil {
return err

View File

@@ -44,6 +44,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
UpCommand: *settings.UpCommand,
DownCommand: *settings.DownCommand,
ListeningPort: *settings.ListeningPort,
},
},

View File

@@ -9,7 +9,7 @@ import (
"github.com/qdm12/gluetun/internal/command"
)
func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
commandTemplate string, ports []uint16,
) (err error) {
portStrings := make([]string, len(ports))

View File

@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_Service_runUpCommand(t *testing.T) {
func Test_Service_runCommand(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
@@ -22,7 +22,7 @@ func Test_Service_runUpCommand(t *testing.T) {
logger := NewMockLogger(ctrl)
logger.EXPECT().Info("1234,5678")
err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
err := runCommand(ctx, cmder, logger, commandTemplate, ports)
require.NoError(t, err)
}

View File

@@ -13,6 +13,7 @@ type Settings struct {
PortForwarder PortForwarder
Filepath string
UpCommand string
DownCommand string
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
ServerName string // needed for PIA
CanPortForward bool // needed for PIA
@@ -26,6 +27,7 @@ func (s Settings) Copy() (copied Settings) {
copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.UpCommand = s.UpCommand
copied.DownCommand = s.DownCommand
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.CanPortForward = s.CanPortForward
@@ -40,6 +42,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
s.DownCommand = gosettings.OverrideWithComparable(s.DownCommand, update.DownCommand)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)

View File

@@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.portMutex.Unlock()
if s.settings.UpCommand != "" {
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
if err != nil {
err = fmt.Errorf("running up command: %w", err)
s.logger.Error(err.Error())

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"time"
)
func (s *Service) Stop() (err error) {
@@ -30,6 +31,17 @@ func (s *Service) cleanup() (err error) {
s.portMutex.Lock()
defer s.portMutex.Unlock()
if s.settings.DownCommand != "" {
const downTimeout = 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
defer cancel()
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
if err != nil {
err = fmt.Errorf("running down command: %w", err)
s.logger.Error(err.Error())
}
}
for _, port := range s.ports {
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
if err != nil {