feat(portforwarding): VPN_PORT_FORWARDING_DOWN_COMMAND option
This commit is contained in:
@@ -126,6 +126,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
VPN_PORT_FORWARDING_USERNAME= \
|
VPN_PORT_FORWARDING_USERNAME= \
|
||||||
VPN_PORT_FORWARDING_PASSWORD= \
|
VPN_PORT_FORWARDING_PASSWORD= \
|
||||||
VPN_PORT_FORWARDING_UP_COMMAND= \
|
VPN_PORT_FORWARDING_UP_COMMAND= \
|
||||||
|
VPN_PORT_FORWARDING_DOWN_COMMAND= \
|
||||||
# # Cyberghost only:
|
# # Cyberghost only:
|
||||||
OPENVPN_CERT= \
|
OPENVPN_CERT= \
|
||||||
OPENVPN_KEY= \
|
OPENVPN_KEY= \
|
||||||
|
|||||||
@@ -33,6 +33,10 @@ type PortForwarding struct {
|
|||||||
// It can be the empty string to indicate not to run a command.
|
// It can be the empty string to indicate not to run a command.
|
||||||
// It cannot be nil in the internal state.
|
// It cannot be nil in the internal state.
|
||||||
UpCommand *string `json:"up_command"`
|
UpCommand *string `json:"up_command"`
|
||||||
|
// DownCommand is the command to use after the port forwarding goes down.
|
||||||
|
// It can be the empty string to indicate to NOT run a command.
|
||||||
|
// It cannot be nil in the internal state.
|
||||||
|
DownCommand *string `json:"down_command"`
|
||||||
// ListeningPort is the port traffic would be redirected to from the
|
// ListeningPort is the port traffic would be redirected to from the
|
||||||
// forwarded port. The redirection is disabled if it is set to 0, which
|
// forwarded port. The redirection is disabled if it is set to 0, which
|
||||||
// is its default as well.
|
// is its default as well.
|
||||||
@@ -89,6 +93,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) {
|
|||||||
Provider: gosettings.CopyPointer(p.Provider),
|
Provider: gosettings.CopyPointer(p.Provider),
|
||||||
Filepath: gosettings.CopyPointer(p.Filepath),
|
Filepath: gosettings.CopyPointer(p.Filepath),
|
||||||
UpCommand: gosettings.CopyPointer(p.UpCommand),
|
UpCommand: gosettings.CopyPointer(p.UpCommand),
|
||||||
|
DownCommand: gosettings.CopyPointer(p.DownCommand),
|
||||||
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
|
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
|
||||||
Username: p.Username,
|
Username: p.Username,
|
||||||
Password: p.Password,
|
Password: p.Password,
|
||||||
@@ -100,6 +105,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) {
|
|||||||
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
|
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
|
||||||
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
|
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
|
||||||
p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand)
|
p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand)
|
||||||
|
p.DownCommand = gosettings.OverrideWithPointer(p.DownCommand, other.DownCommand)
|
||||||
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
|
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
|
||||||
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
|
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
|
||||||
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
|
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
|
||||||
@@ -110,6 +116,7 @@ func (p *PortForwarding) setDefaults() {
|
|||||||
p.Provider = gosettings.DefaultPointer(p.Provider, "")
|
p.Provider = gosettings.DefaultPointer(p.Provider, "")
|
||||||
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
|
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
|
||||||
p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "")
|
p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "")
|
||||||
|
p.DownCommand = gosettings.DefaultPointer(p.DownCommand, "")
|
||||||
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
|
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,9 +149,11 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
|
|||||||
}
|
}
|
||||||
node.Appendf("Forwarded port file path: %s", filepath)
|
node.Appendf("Forwarded port file path: %s", filepath)
|
||||||
|
|
||||||
command := *p.UpCommand
|
if *p.UpCommand != "" {
|
||||||
if command != "" {
|
node.Appendf("Forwarded port up command: %s", *p.UpCommand)
|
||||||
node.Appendf("Forwarded port command: %s", command)
|
}
|
||||||
|
if *p.DownCommand != "" {
|
||||||
|
node.Appendf("Forwarded port down command: %s", *p.DownCommand)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Username != "" {
|
if p.Username != "" {
|
||||||
@@ -178,6 +187,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) {
|
|||||||
p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND",
|
p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND",
|
||||||
reader.ForceLowercase(false))
|
reader.ForceLowercase(false))
|
||||||
|
|
||||||
|
p.DownCommand = r.Get("VPN_PORT_FORWARDING_DOWN_COMMAND",
|
||||||
|
reader.ForceLowercase(false))
|
||||||
|
|
||||||
p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
|
p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
|||||||
Enabled: settings.Enabled,
|
Enabled: settings.Enabled,
|
||||||
Filepath: *settings.Filepath,
|
Filepath: *settings.Filepath,
|
||||||
UpCommand: *settings.UpCommand,
|
UpCommand: *settings.UpCommand,
|
||||||
|
DownCommand: *settings.DownCommand,
|
||||||
ListeningPort: *settings.ListeningPort,
|
ListeningPort: *settings.ListeningPort,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/command"
|
"github.com/qdm12/gluetun/internal/command"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
|
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
|
||||||
commandTemplate string, ports []uint16,
|
commandTemplate string, ports []uint16,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
portStrings := make([]string, len(ports))
|
portStrings := make([]string, len(ports))
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Service_runUpCommand(t *testing.T) {
|
func Test_Service_runCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ func Test_Service_runUpCommand(t *testing.T) {
|
|||||||
logger := NewMockLogger(ctrl)
|
logger := NewMockLogger(ctrl)
|
||||||
logger.EXPECT().Info("1234,5678")
|
logger.EXPECT().Info("1234,5678")
|
||||||
|
|
||||||
err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
|
err := runCommand(ctx, cmder, logger, commandTemplate, ports)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ type Settings struct {
|
|||||||
PortForwarder PortForwarder
|
PortForwarder PortForwarder
|
||||||
Filepath string
|
Filepath string
|
||||||
UpCommand string
|
UpCommand string
|
||||||
|
DownCommand string
|
||||||
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
||||||
ServerName string // needed for PIA
|
ServerName string // needed for PIA
|
||||||
CanPortForward bool // needed for PIA
|
CanPortForward bool // needed for PIA
|
||||||
@@ -26,6 +27,7 @@ func (s Settings) Copy() (copied Settings) {
|
|||||||
copied.PortForwarder = s.PortForwarder
|
copied.PortForwarder = s.PortForwarder
|
||||||
copied.Filepath = s.Filepath
|
copied.Filepath = s.Filepath
|
||||||
copied.UpCommand = s.UpCommand
|
copied.UpCommand = s.UpCommand
|
||||||
|
copied.DownCommand = s.DownCommand
|
||||||
copied.Interface = s.Interface
|
copied.Interface = s.Interface
|
||||||
copied.ServerName = s.ServerName
|
copied.ServerName = s.ServerName
|
||||||
copied.CanPortForward = s.CanPortForward
|
copied.CanPortForward = s.CanPortForward
|
||||||
@@ -40,6 +42,7 @@ func (s *Settings) OverrideWith(update Settings) {
|
|||||||
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
|
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
|
||||||
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
||||||
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
|
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
|
||||||
|
s.DownCommand = gosettings.OverrideWithComparable(s.DownCommand, update.DownCommand)
|
||||||
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
||||||
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
||||||
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
|||||||
s.portMutex.Unlock()
|
s.portMutex.Unlock()
|
||||||
|
|
||||||
if s.settings.UpCommand != "" {
|
if s.settings.UpCommand != "" {
|
||||||
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("running up command: %w", err)
|
err = fmt.Errorf("running up command: %w", err)
|
||||||
s.logger.Error(err.Error())
|
s.logger.Error(err.Error())
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Service) Stop() (err error) {
|
func (s *Service) Stop() (err error) {
|
||||||
@@ -30,6 +31,17 @@ func (s *Service) cleanup() (err error) {
|
|||||||
s.portMutex.Lock()
|
s.portMutex.Lock()
|
||||||
defer s.portMutex.Unlock()
|
defer s.portMutex.Unlock()
|
||||||
|
|
||||||
|
if s.settings.DownCommand != "" {
|
||||||
|
const downTimeout = 60 * time.Second
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("running down command: %w", err)
|
||||||
|
s.logger.Error(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, port := range s.ports {
|
for _, port := range s.ports {
|
||||||
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
|
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user