feat(portforwarding): VPN_PORT_FORWARDING_DOWN_COMMAND option
This commit is contained in:
@@ -126,6 +126,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
VPN_PORT_FORWARDING_USERNAME= \
|
||||
VPN_PORT_FORWARDING_PASSWORD= \
|
||||
VPN_PORT_FORWARDING_UP_COMMAND= \
|
||||
VPN_PORT_FORWARDING_DOWN_COMMAND= \
|
||||
# # Cyberghost only:
|
||||
OPENVPN_CERT= \
|
||||
OPENVPN_KEY= \
|
||||
|
||||
@@ -33,6 +33,10 @@ type PortForwarding struct {
|
||||
// It can be the empty string to indicate not to run a command.
|
||||
// It cannot be nil in the internal state.
|
||||
UpCommand *string `json:"up_command"`
|
||||
// DownCommand is the command to use after the port forwarding goes down.
|
||||
// It can be the empty string to indicate to NOT run a command.
|
||||
// It cannot be nil in the internal state.
|
||||
DownCommand *string `json:"down_command"`
|
||||
// ListeningPort is the port traffic would be redirected to from the
|
||||
// forwarded port. The redirection is disabled if it is set to 0, which
|
||||
// is its default as well.
|
||||
@@ -89,6 +93,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) {
|
||||
Provider: gosettings.CopyPointer(p.Provider),
|
||||
Filepath: gosettings.CopyPointer(p.Filepath),
|
||||
UpCommand: gosettings.CopyPointer(p.UpCommand),
|
||||
DownCommand: gosettings.CopyPointer(p.DownCommand),
|
||||
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
@@ -100,6 +105,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) {
|
||||
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
|
||||
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
|
||||
p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand)
|
||||
p.DownCommand = gosettings.OverrideWithPointer(p.DownCommand, other.DownCommand)
|
||||
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
|
||||
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
|
||||
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
|
||||
@@ -110,6 +116,7 @@ func (p *PortForwarding) setDefaults() {
|
||||
p.Provider = gosettings.DefaultPointer(p.Provider, "")
|
||||
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
|
||||
p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "")
|
||||
p.DownCommand = gosettings.DefaultPointer(p.DownCommand, "")
|
||||
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
|
||||
}
|
||||
|
||||
@@ -142,9 +149,11 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
|
||||
}
|
||||
node.Appendf("Forwarded port file path: %s", filepath)
|
||||
|
||||
command := *p.UpCommand
|
||||
if command != "" {
|
||||
node.Appendf("Forwarded port command: %s", command)
|
||||
if *p.UpCommand != "" {
|
||||
node.Appendf("Forwarded port up command: %s", *p.UpCommand)
|
||||
}
|
||||
if *p.DownCommand != "" {
|
||||
node.Appendf("Forwarded port down command: %s", *p.DownCommand)
|
||||
}
|
||||
|
||||
if p.Username != "" {
|
||||
@@ -178,6 +187,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) {
|
||||
p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND",
|
||||
reader.ForceLowercase(false))
|
||||
|
||||
p.DownCommand = r.Get("VPN_PORT_FORWARDING_DOWN_COMMAND",
|
||||
reader.ForceLowercase(false))
|
||||
|
||||
p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -44,6 +44,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||
Enabled: settings.Enabled,
|
||||
Filepath: *settings.Filepath,
|
||||
UpCommand: *settings.UpCommand,
|
||||
DownCommand: *settings.DownCommand,
|
||||
ListeningPort: *settings.ListeningPort,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/command"
|
||||
)
|
||||
|
||||
func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
|
||||
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
|
||||
commandTemplate string, ports []uint16,
|
||||
) (err error) {
|
||||
portStrings := make([]string, len(ports))
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Service_runUpCommand(t *testing.T) {
|
||||
func Test_Service_runCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
@@ -22,7 +22,7 @@ func Test_Service_runUpCommand(t *testing.T) {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Info("1234,5678")
|
||||
|
||||
err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
|
||||
err := runCommand(ctx, cmder, logger, commandTemplate, ports)
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type Settings struct {
|
||||
PortForwarder PortForwarder
|
||||
Filepath string
|
||||
UpCommand string
|
||||
DownCommand string
|
||||
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
||||
ServerName string // needed for PIA
|
||||
CanPortForward bool // needed for PIA
|
||||
@@ -26,6 +27,7 @@ func (s Settings) Copy() (copied Settings) {
|
||||
copied.PortForwarder = s.PortForwarder
|
||||
copied.Filepath = s.Filepath
|
||||
copied.UpCommand = s.UpCommand
|
||||
copied.DownCommand = s.DownCommand
|
||||
copied.Interface = s.Interface
|
||||
copied.ServerName = s.ServerName
|
||||
copied.CanPortForward = s.CanPortForward
|
||||
@@ -40,6 +42,7 @@ func (s *Settings) OverrideWith(update Settings) {
|
||||
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
|
||||
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
||||
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
|
||||
s.DownCommand = gosettings.OverrideWithComparable(s.DownCommand, update.DownCommand)
|
||||
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
||||
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
||||
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
||||
|
||||
@@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
s.portMutex.Unlock()
|
||||
|
||||
if s.settings.UpCommand != "" {
|
||||
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("running up command: %w", err)
|
||||
s.logger.Error(err.Error())
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Service) Stop() (err error) {
|
||||
@@ -30,6 +31,17 @@ func (s *Service) cleanup() (err error) {
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
if s.settings.DownCommand != "" {
|
||||
const downTimeout = 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
|
||||
defer cancel()
|
||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("running down command: %w", err)
|
||||
s.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
for _, port := range s.ports {
|
||||
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user