feat(portforwarding): allow running script upon port forwarding success (#2399)
This commit is contained in:
@@ -3,6 +3,7 @@ package portforward
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
@@ -29,3 +30,8 @@ type Logger interface {
|
||||
Warn(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
type Cmder interface {
|
||||
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
||||
waitError <-chan error, startErr error)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ type Loop struct {
|
||||
client *http.Client
|
||||
portAllower PortAllower
|
||||
logger Logger
|
||||
cmder Cmder
|
||||
// Fixed parameters
|
||||
uid, gid int
|
||||
// Internal channels and locks
|
||||
@@ -34,7 +35,7 @@ type Loop struct {
|
||||
|
||||
func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||
client *http.Client, portAllower PortAllower,
|
||||
logger Logger, uid, gid int,
|
||||
logger Logger, cmder Cmder, uid, gid int,
|
||||
) *Loop {
|
||||
return &Loop{
|
||||
settings: Settings{
|
||||
@@ -42,6 +43,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||
Service: service.Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Filepath: *settings.Filepath,
|
||||
UpCommand: *settings.UpCommand,
|
||||
ListeningPort: *settings.ListeningPort,
|
||||
},
|
||||
},
|
||||
@@ -49,6 +51,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
cmder: cmder,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
}
|
||||
@@ -115,7 +118,7 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
||||
*serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp
|
||||
|
||||
l.service = service.New(serviceSettings, l.routing, l.client,
|
||||
l.portAllower, l.logger, l.uid, l.gid)
|
||||
l.portAllower, l.logger, l.cmder, l.uid, l.gid)
|
||||
|
||||
var err error
|
||||
serviceRunError, err = l.service.Start(runCtx)
|
||||
|
||||
59
internal/portforward/service/command.go
Normal file
59
internal/portforward/service/command.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/command"
|
||||
)
|
||||
|
||||
func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
|
||||
commandTemplate string, ports []uint16,
|
||||
) (err error) {
|
||||
portStrings := make([]string, len(ports))
|
||||
for i, port := range ports {
|
||||
portStrings[i] = fmt.Sprint(int(port))
|
||||
}
|
||||
portsString := strings.Join(portStrings, ",")
|
||||
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
|
||||
args, err := command.Split(commandString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing command: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, args[0], args[1:]...) // #nosec G204
|
||||
stdout, stderr, waitError, err := cmder.Start(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
streamCtx, streamCancel := context.WithCancel(context.Background())
|
||||
streamDone := make(chan struct{})
|
||||
go streamLines(streamCtx, streamDone, logger, stdout, stderr)
|
||||
|
||||
err = <-waitError
|
||||
streamCancel()
|
||||
<-streamDone
|
||||
return err
|
||||
}
|
||||
|
||||
func streamLines(ctx context.Context, done chan<- struct{},
|
||||
logger Logger, stdout, stderr <-chan string,
|
||||
) {
|
||||
defer close(done)
|
||||
|
||||
var line string
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case line = <-stdout:
|
||||
logger.Info(line)
|
||||
case line = <-stderr:
|
||||
logger.Error(line)
|
||||
}
|
||||
}
|
||||
}
|
||||
28
internal/portforward/service/command_test.go
Normal file
28
internal/portforward/service/command_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build linux
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/command"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Service_runUpCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
cmder := command.New()
|
||||
const commandTemplate = `/bin/sh -c "echo {{PORTS}}"`
|
||||
ports := []uint16{1234, 5678}
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Info("1234,5678")
|
||||
|
||||
err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
@@ -32,3 +33,8 @@ type PortForwarder interface {
|
||||
ports []uint16, err error)
|
||||
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
|
||||
}
|
||||
|
||||
type Cmder interface {
|
||||
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
||||
waitError <-chan error, startErr error)
|
||||
}
|
||||
|
||||
3
internal/portforward/service/mocks_generate_test.go
Normal file
3
internal/portforward/service/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package service
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||
82
internal/portforward/service/mocks_test.go
Normal file
82
internal/portforward/service/mocks_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/portforward/service (interfaces: Logger)
|
||||
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
|
||||
// Warn mocks base method.
|
||||
func (m *MockLogger) Warn(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Warn", arg0)
|
||||
}
|
||||
|
||||
// Warn indicates an expected call of Warn.
|
||||
func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0)
|
||||
}
|
||||
@@ -19,6 +19,7 @@ type Service struct {
|
||||
client *http.Client
|
||||
portAllower PortAllower
|
||||
logger Logger
|
||||
cmder Cmder
|
||||
// Internal channels and locks
|
||||
startStopMutex sync.Mutex
|
||||
keepPortCancel context.CancelFunc
|
||||
@@ -26,7 +27,7 @@ type Service struct {
|
||||
}
|
||||
|
||||
func New(settings Settings, routing Routing, client *http.Client,
|
||||
portAllower PortAllower, logger Logger, puid, pgid int,
|
||||
portAllower PortAllower, logger Logger, cmder Cmder, puid, pgid int,
|
||||
) *Service {
|
||||
return &Service{
|
||||
// Fixed parameters
|
||||
@@ -38,6 +39,7 @@ func New(settings Settings, routing Routing, client *http.Client,
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
cmder: cmder,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ type Settings struct {
|
||||
Enabled *bool
|
||||
PortForwarder PortForwarder
|
||||
Filepath string
|
||||
UpCommand string
|
||||
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
||||
ServerName string // needed for PIA
|
||||
CanPortForward bool // needed for PIA
|
||||
@@ -24,6 +25,7 @@ func (s Settings) Copy() (copied Settings) {
|
||||
copied.Enabled = gosettings.CopyPointer(s.Enabled)
|
||||
copied.PortForwarder = s.PortForwarder
|
||||
copied.Filepath = s.Filepath
|
||||
copied.UpCommand = s.UpCommand
|
||||
copied.Interface = s.Interface
|
||||
copied.ServerName = s.ServerName
|
||||
copied.CanPortForward = s.CanPortForward
|
||||
@@ -37,6 +39,7 @@ func (s *Settings) OverrideWith(update Settings) {
|
||||
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
|
||||
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
|
||||
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
||||
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
|
||||
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
||||
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
||||
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
||||
|
||||
@@ -73,6 +73,14 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
s.ports = ports
|
||||
s.portMutex.Unlock()
|
||||
|
||||
if s.settings.UpCommand != "" {
|
||||
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("running up command: %w", err)
|
||||
s.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
||||
s.keepPortCancel = keepPortCancel
|
||||
runErrorCh := make(chan error)
|
||||
|
||||
Reference in New Issue
Block a user