feat(portforwarding): allow running script upon port forwarding success (#2399)
This commit is contained in:
@@ -125,6 +125,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \
|
VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \
|
||||||
VPN_PORT_FORWARDING_USERNAME= \
|
VPN_PORT_FORWARDING_USERNAME= \
|
||||||
VPN_PORT_FORWARDING_PASSWORD= \
|
VPN_PORT_FORWARDING_PASSWORD= \
|
||||||
|
VPN_PORT_FORWARDING_UP_COMMAND= \
|
||||||
# # Cyberghost only:
|
# # Cyberghost only:
|
||||||
OPENVPN_CERT= \
|
OPENVPN_CERT= \
|
||||||
OPENVPN_KEY= \
|
OPENVPN_KEY= \
|
||||||
|
|||||||
@@ -380,7 +380,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
|
|
||||||
portForwardLogger := logger.New(log.SetComponent("port forwarding"))
|
portForwardLogger := logger.New(log.SetComponent("port forwarding"))
|
||||||
portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding,
|
portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding,
|
||||||
routingConf, httpClient, firewallConf, portForwardLogger, puid, pgid)
|
routingConf, httpClient, firewallConf, portForwardLogger, cmder, puid, pgid)
|
||||||
portForwardRunError, err := portForwardLooper.Start(ctx)
|
portForwardRunError, err := portForwardLooper.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("starting port forwarding loop: %w", err)
|
return fmt.Errorf("starting port forwarding loop: %w", err)
|
||||||
|
|||||||
150
internal/command/split.go
Normal file
150
internal/command/split.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package command
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrCommandEmpty = errors.New("command is empty")
|
||||||
|
ErrSingleQuoteUnterminated = errors.New("unterminated single-quoted string")
|
||||||
|
ErrDoubleQuoteUnterminated = errors.New("unterminated double-quoted string")
|
||||||
|
ErrEscapeUnterminated = errors.New("unterminated backslash-escape")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Split splits a command string into a slice of arguments.
|
||||||
|
// This is especially important for commands such as:
|
||||||
|
// /bin/sh -c "echo hello"
|
||||||
|
// which should be split into: ["/bin/sh", "-c", "echo hello"]
|
||||||
|
// It supports backslash-escapes, single-quotes and double-quotes.
|
||||||
|
// It does not support:
|
||||||
|
// - the $" quoting style.
|
||||||
|
// - expansion (brace, shell or pathname).
|
||||||
|
func Split(command string) (words []string, err error) {
|
||||||
|
if command == "" {
|
||||||
|
return nil, fmt.Errorf("%w", ErrCommandEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
const bufferSize = 1024
|
||||||
|
buffer := bytes.NewBuffer(make([]byte, bufferSize))
|
||||||
|
|
||||||
|
startIndex := 0
|
||||||
|
|
||||||
|
for startIndex < len(command) {
|
||||||
|
// skip any split characters at the start
|
||||||
|
character, runeSize := utf8.DecodeRuneInString(command[startIndex:])
|
||||||
|
switch {
|
||||||
|
case strings.ContainsRune(" \n\t", character):
|
||||||
|
startIndex += runeSize
|
||||||
|
case character == '\\':
|
||||||
|
// Look ahead to eventually skip an escaped newline
|
||||||
|
if command[startIndex+runeSize:] == "" {
|
||||||
|
return nil, fmt.Errorf("%w: %q", ErrEscapeUnterminated, command)
|
||||||
|
}
|
||||||
|
character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:])
|
||||||
|
if character == '\n' {
|
||||||
|
startIndex += runeSize + runeSize // backslash and newline
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var word string
|
||||||
|
buffer.Reset()
|
||||||
|
word, startIndex, err = splitWord(command, startIndex, buffer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("splitting word in %q: %w", command, err)
|
||||||
|
}
|
||||||
|
words = append(words, word)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return words, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING: buffer must be cleared before calling this function.
|
||||||
|
func splitWord(input string, startIndex int, buffer *bytes.Buffer) (
|
||||||
|
word string, newStartIndex int, err error,
|
||||||
|
) {
|
||||||
|
cursor := startIndex
|
||||||
|
for cursor < len(input) {
|
||||||
|
character, runeLength := utf8.DecodeRuneInString(input[cursor:])
|
||||||
|
cursor += runeLength
|
||||||
|
if character == '"' ||
|
||||||
|
character == '\'' ||
|
||||||
|
character == '\\' ||
|
||||||
|
character == ' ' ||
|
||||||
|
character == '\n' ||
|
||||||
|
character == '\t' {
|
||||||
|
buffer.WriteString(input[startIndex : cursor-runeLength])
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.ContainsRune(" \n\t", character): // spacing character
|
||||||
|
return buffer.String(), cursor, nil
|
||||||
|
case character == '"':
|
||||||
|
return handleDoubleQuoted(input, cursor, buffer)
|
||||||
|
case character == '\'':
|
||||||
|
return handleSingleQuoted(input, cursor, buffer)
|
||||||
|
case character == '\\':
|
||||||
|
return handleEscaped(input, cursor, buffer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.WriteString(input[startIndex:])
|
||||||
|
return buffer.String(), len(input), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||||
|
word string, newStartIndex int, err error,
|
||||||
|
) {
|
||||||
|
cursor := startIndex
|
||||||
|
for cursor < len(input) {
|
||||||
|
nextCharacter, nextRuneLength := utf8.DecodeRuneInString(input[cursor:])
|
||||||
|
cursor += nextRuneLength
|
||||||
|
switch nextCharacter {
|
||||||
|
case '"': // end of the double quoted string
|
||||||
|
buffer.WriteString(input[startIndex : cursor-nextRuneLength])
|
||||||
|
return splitWord(input, cursor, buffer)
|
||||||
|
case '\\': // escaped character
|
||||||
|
escapedCharacter, escapedRuneLength := utf8.DecodeRuneInString(input[cursor:])
|
||||||
|
cursor += escapedRuneLength
|
||||||
|
if !strings.ContainsRune("$`\"\n\\", escapedCharacter) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
buffer.WriteString(input[startIndex : cursor-nextRuneLength-escapedRuneLength])
|
||||||
|
if escapedCharacter != '\n' {
|
||||||
|
// skip backslash entirely for the newline character
|
||||||
|
buffer.WriteRune(escapedCharacter)
|
||||||
|
}
|
||||||
|
startIndex = cursor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("%w", ErrDoubleQuoteUnterminated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) (
|
||||||
|
word string, newStartIndex int, err error,
|
||||||
|
) {
|
||||||
|
closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'')
|
||||||
|
if closingQuoteIndex == -1 {
|
||||||
|
return "", 0, fmt.Errorf("%w", ErrSingleQuoteUnterminated)
|
||||||
|
}
|
||||||
|
buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex])
|
||||||
|
const singleQuoteRuneLength = 1
|
||||||
|
startIndex += closingQuoteIndex + singleQuoteRuneLength
|
||||||
|
return splitWord(input, startIndex, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) (
|
||||||
|
word string, newStartIndex int, err error,
|
||||||
|
) {
|
||||||
|
if input[startIndex:] == "" {
|
||||||
|
return "", 0, fmt.Errorf("%w", ErrEscapeUnterminated)
|
||||||
|
}
|
||||||
|
character, runeLength := utf8.DecodeRuneInString(input[startIndex:])
|
||||||
|
if character != '\n' { // backslash-escaped newline is ignored
|
||||||
|
buffer.WriteString(input[startIndex : startIndex+runeLength])
|
||||||
|
}
|
||||||
|
startIndex += runeLength
|
||||||
|
return splitWord(input, startIndex, buffer)
|
||||||
|
}
|
||||||
110
internal/command/split_test.go
Normal file
110
internal/command/split_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package command
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Split(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
command string
|
||||||
|
words []string
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty": {
|
||||||
|
command: "",
|
||||||
|
errWrapped: ErrCommandEmpty,
|
||||||
|
errMessage: "command is empty",
|
||||||
|
},
|
||||||
|
"concrete_sh_command": {
|
||||||
|
command: `/bin/sh -c "echo 123"`,
|
||||||
|
words: []string{"/bin/sh", "-c", "echo 123"},
|
||||||
|
},
|
||||||
|
"single_word": {
|
||||||
|
command: "word1",
|
||||||
|
words: []string{"word1"},
|
||||||
|
},
|
||||||
|
"two_words_single_space": {
|
||||||
|
command: "word1 word2",
|
||||||
|
words: []string{"word1", "word2"},
|
||||||
|
},
|
||||||
|
"two_words_multiple_space": {
|
||||||
|
command: "word1 word2",
|
||||||
|
words: []string{"word1", "word2"},
|
||||||
|
},
|
||||||
|
"two_words_no_expansion": {
|
||||||
|
command: "word1* word2?",
|
||||||
|
words: []string{"word1*", "word2?"},
|
||||||
|
},
|
||||||
|
"escaped_single quote": {
|
||||||
|
command: "ain\\'t good",
|
||||||
|
words: []string{"ain't", "good"},
|
||||||
|
},
|
||||||
|
"escaped_single_quote_all_single_quoted": {
|
||||||
|
command: "'ain'\\''t good'",
|
||||||
|
words: []string{"ain't good"},
|
||||||
|
},
|
||||||
|
"empty_single_quoted": {
|
||||||
|
command: "word1 '' word2",
|
||||||
|
words: []string{"word1", "", "word2"},
|
||||||
|
},
|
||||||
|
"escaped_newline": {
|
||||||
|
command: "word1\\\nword2",
|
||||||
|
words: []string{"word1word2"},
|
||||||
|
},
|
||||||
|
"quoted_newline": {
|
||||||
|
command: "text \"with\na\" quoted newline",
|
||||||
|
words: []string{"text", "with\na", "quoted", "newline"},
|
||||||
|
},
|
||||||
|
"quoted_escaped_newline": {
|
||||||
|
command: "\"word1\\d\\\\\\\" word2\\\nword3 word4\"",
|
||||||
|
words: []string{"word1\\d\\\" word2word3 word4"},
|
||||||
|
},
|
||||||
|
"escaped_separated_newline": {
|
||||||
|
command: "word1 \\\n word2",
|
||||||
|
words: []string{"word1", "word2"},
|
||||||
|
},
|
||||||
|
"double_quotes_no_spacing": {
|
||||||
|
command: "word1\"word2\"word3",
|
||||||
|
words: []string{"word1word2word3"},
|
||||||
|
},
|
||||||
|
"unterminated_single_quote": {
|
||||||
|
command: "'abc'\\''def",
|
||||||
|
errWrapped: ErrSingleQuoteUnterminated,
|
||||||
|
errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`,
|
||||||
|
},
|
||||||
|
"unterminated_double_quote": {
|
||||||
|
command: "\"abc'def",
|
||||||
|
errWrapped: ErrDoubleQuoteUnterminated,
|
||||||
|
errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`,
|
||||||
|
},
|
||||||
|
"unterminated_escape": {
|
||||||
|
command: "abc\\",
|
||||||
|
errWrapped: ErrEscapeUnterminated,
|
||||||
|
errMessage: `splitting word in "abc\\": unterminated backslash-escape`,
|
||||||
|
},
|
||||||
|
"unterminated_escape_only": {
|
||||||
|
command: " \\",
|
||||||
|
errWrapped: ErrEscapeUnterminated,
|
||||||
|
errMessage: `unterminated backslash-escape: " \\"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
words, err := Split(testCase.command)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.words, words)
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -29,6 +29,10 @@ type PortForwarding struct {
|
|||||||
// to write to a file. It cannot be nil for the
|
// to write to a file. It cannot be nil for the
|
||||||
// internal state
|
// internal state
|
||||||
Filepath *string `json:"status_file_path"`
|
Filepath *string `json:"status_file_path"`
|
||||||
|
// UpCommand is the command to use when the port forwarding is up.
|
||||||
|
// It can be the empty string to indicate not to run a command.
|
||||||
|
// It cannot be nil in the internal state.
|
||||||
|
UpCommand *string `json:"up_command"`
|
||||||
// ListeningPort is the port traffic would be redirected to from the
|
// ListeningPort is the port traffic would be redirected to from the
|
||||||
// forwarded port. The redirection is disabled if it is set to 0, which
|
// forwarded port. The redirection is disabled if it is set to 0, which
|
||||||
// is its default as well.
|
// is its default as well.
|
||||||
@@ -84,6 +88,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) {
|
|||||||
Enabled: gosettings.CopyPointer(p.Enabled),
|
Enabled: gosettings.CopyPointer(p.Enabled),
|
||||||
Provider: gosettings.CopyPointer(p.Provider),
|
Provider: gosettings.CopyPointer(p.Provider),
|
||||||
Filepath: gosettings.CopyPointer(p.Filepath),
|
Filepath: gosettings.CopyPointer(p.Filepath),
|
||||||
|
UpCommand: gosettings.CopyPointer(p.UpCommand),
|
||||||
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
|
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
|
||||||
Username: p.Username,
|
Username: p.Username,
|
||||||
Password: p.Password,
|
Password: p.Password,
|
||||||
@@ -94,6 +99,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) {
|
|||||||
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
|
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
|
||||||
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
|
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
|
||||||
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
|
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
|
||||||
|
p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand)
|
||||||
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
|
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
|
||||||
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
|
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
|
||||||
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
|
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
|
||||||
@@ -103,6 +109,7 @@ func (p *PortForwarding) setDefaults() {
|
|||||||
p.Enabled = gosettings.DefaultPointer(p.Enabled, false)
|
p.Enabled = gosettings.DefaultPointer(p.Enabled, false)
|
||||||
p.Provider = gosettings.DefaultPointer(p.Provider, "")
|
p.Provider = gosettings.DefaultPointer(p.Provider, "")
|
||||||
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
|
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
|
||||||
|
p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "")
|
||||||
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
|
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,6 +142,11 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
|
|||||||
}
|
}
|
||||||
node.Appendf("Forwarded port file path: %s", filepath)
|
node.Appendf("Forwarded port file path: %s", filepath)
|
||||||
|
|
||||||
|
command := *p.UpCommand
|
||||||
|
if command != "" {
|
||||||
|
node.Appendf("Forwarded port command: %s", command)
|
||||||
|
}
|
||||||
|
|
||||||
if p.Username != "" {
|
if p.Username != "" {
|
||||||
credentialsNode := node.Appendf("Credentials:")
|
credentialsNode := node.Appendf("Credentials:")
|
||||||
credentialsNode.Appendf("Username: %s", p.Username)
|
credentialsNode.Appendf("Username: %s", p.Username)
|
||||||
@@ -163,6 +175,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) {
|
|||||||
"PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE",
|
"PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE",
|
||||||
))
|
))
|
||||||
|
|
||||||
|
p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND",
|
||||||
|
reader.ForceLowercase(false))
|
||||||
|
|
||||||
p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
|
p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package portforward
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service interface {
|
type Service interface {
|
||||||
@@ -29,3 +30,8 @@ type Logger interface {
|
|||||||
Warn(s string)
|
Warn(s string)
|
||||||
Error(s string)
|
Error(s string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Cmder interface {
|
||||||
|
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
||||||
|
waitError <-chan error, startErr error)
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type Loop struct {
|
|||||||
client *http.Client
|
client *http.Client
|
||||||
portAllower PortAllower
|
portAllower PortAllower
|
||||||
logger Logger
|
logger Logger
|
||||||
|
cmder Cmder
|
||||||
// Fixed parameters
|
// Fixed parameters
|
||||||
uid, gid int
|
uid, gid int
|
||||||
// Internal channels and locks
|
// Internal channels and locks
|
||||||
@@ -34,7 +35,7 @@ type Loop struct {
|
|||||||
|
|
||||||
func NewLoop(settings settings.PortForwarding, routing Routing,
|
func NewLoop(settings settings.PortForwarding, routing Routing,
|
||||||
client *http.Client, portAllower PortAllower,
|
client *http.Client, portAllower PortAllower,
|
||||||
logger Logger, uid, gid int,
|
logger Logger, cmder Cmder, uid, gid int,
|
||||||
) *Loop {
|
) *Loop {
|
||||||
return &Loop{
|
return &Loop{
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -42,6 +43,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
|||||||
Service: service.Settings{
|
Service: service.Settings{
|
||||||
Enabled: settings.Enabled,
|
Enabled: settings.Enabled,
|
||||||
Filepath: *settings.Filepath,
|
Filepath: *settings.Filepath,
|
||||||
|
UpCommand: *settings.UpCommand,
|
||||||
ListeningPort: *settings.ListeningPort,
|
ListeningPort: *settings.ListeningPort,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -49,6 +51,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
|
|||||||
client: client,
|
client: client,
|
||||||
portAllower: portAllower,
|
portAllower: portAllower,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
cmder: cmder,
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
}
|
}
|
||||||
@@ -115,7 +118,7 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
|
|||||||
*serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp
|
*serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp
|
||||||
|
|
||||||
l.service = service.New(serviceSettings, l.routing, l.client,
|
l.service = service.New(serviceSettings, l.routing, l.client,
|
||||||
l.portAllower, l.logger, l.uid, l.gid)
|
l.portAllower, l.logger, l.cmder, l.uid, l.gid)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
serviceRunError, err = l.service.Start(runCtx)
|
serviceRunError, err = l.service.Start(runCtx)
|
||||||
|
|||||||
59
internal/portforward/service/command.go
Normal file
59
internal/portforward/service/command.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/command"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
|
||||||
|
commandTemplate string, ports []uint16,
|
||||||
|
) (err error) {
|
||||||
|
portStrings := make([]string, len(ports))
|
||||||
|
for i, port := range ports {
|
||||||
|
portStrings[i] = fmt.Sprint(int(port))
|
||||||
|
}
|
||||||
|
portsString := strings.Join(portStrings, ",")
|
||||||
|
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
|
||||||
|
args, err := command.Split(commandString)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, args[0], args[1:]...) // #nosec G204
|
||||||
|
stdout, stderr, waitError, err := cmder.Start(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
streamCtx, streamCancel := context.WithCancel(context.Background())
|
||||||
|
streamDone := make(chan struct{})
|
||||||
|
go streamLines(streamCtx, streamDone, logger, stdout, stderr)
|
||||||
|
|
||||||
|
err = <-waitError
|
||||||
|
streamCancel()
|
||||||
|
<-streamDone
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamLines(ctx context.Context, done chan<- struct{},
|
||||||
|
logger Logger, stdout, stderr <-chan string,
|
||||||
|
) {
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
var line string
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case line = <-stdout:
|
||||||
|
logger.Info(line)
|
||||||
|
case line = <-stderr:
|
||||||
|
logger.Error(line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
28
internal/portforward/service/command_test.go
Normal file
28
internal/portforward/service/command_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
"github.com/qdm12/gluetun/internal/command"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Service_runUpCommand(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cmder := command.New()
|
||||||
|
const commandTemplate = `/bin/sh -c "echo {{PORTS}}"`
|
||||||
|
ports := []uint16{1234, 5678}
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
logger.EXPECT().Info("1234,5678")
|
||||||
|
|
||||||
|
err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
)
|
)
|
||||||
@@ -32,3 +33,8 @@ type PortForwarder interface {
|
|||||||
ports []uint16, err error)
|
ports []uint16, err error)
|
||||||
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
|
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Cmder interface {
|
||||||
|
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
||||||
|
waitError <-chan error, startErr error)
|
||||||
|
}
|
||||||
|
|||||||
3
internal/portforward/service/mocks_generate_test.go
Normal file
3
internal/portforward/service/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger
|
||||||
82
internal/portforward/service/mocks_test.go
Normal file
82
internal/portforward/service/mocks_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/portforward/service (interfaces: Logger)
|
||||||
|
|
||||||
|
// Package service is a generated GoMock package.
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLogger is a mock of Logger interface.
|
||||||
|
type MockLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||||
|
type MockLoggerMockRecorder struct {
|
||||||
|
mock *MockLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLogger creates a new mock instance.
|
||||||
|
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
mock := &MockLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug mocks base method.
|
||||||
|
func (m *MockLogger) Debug(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Debug", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug indicates an expected call of Debug.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error mocks base method.
|
||||||
|
func (m *MockLogger) Error(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Error", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error indicates an expected call of Error.
|
||||||
|
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info mocks base method.
|
||||||
|
func (m *MockLogger) Info(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Info", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info indicates an expected call of Info.
|
||||||
|
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn mocks base method.
|
||||||
|
func (m *MockLogger) Warn(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Warn", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn indicates an expected call of Warn.
|
||||||
|
func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0)
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ type Service struct {
|
|||||||
client *http.Client
|
client *http.Client
|
||||||
portAllower PortAllower
|
portAllower PortAllower
|
||||||
logger Logger
|
logger Logger
|
||||||
|
cmder Cmder
|
||||||
// Internal channels and locks
|
// Internal channels and locks
|
||||||
startStopMutex sync.Mutex
|
startStopMutex sync.Mutex
|
||||||
keepPortCancel context.CancelFunc
|
keepPortCancel context.CancelFunc
|
||||||
@@ -26,7 +27,7 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(settings Settings, routing Routing, client *http.Client,
|
func New(settings Settings, routing Routing, client *http.Client,
|
||||||
portAllower PortAllower, logger Logger, puid, pgid int,
|
portAllower PortAllower, logger Logger, cmder Cmder, puid, pgid int,
|
||||||
) *Service {
|
) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
// Fixed parameters
|
// Fixed parameters
|
||||||
@@ -38,6 +39,7 @@ func New(settings Settings, routing Routing, client *http.Client,
|
|||||||
client: client,
|
client: client,
|
||||||
portAllower: portAllower,
|
portAllower: portAllower,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
cmder: cmder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type Settings struct {
|
|||||||
Enabled *bool
|
Enabled *bool
|
||||||
PortForwarder PortForwarder
|
PortForwarder PortForwarder
|
||||||
Filepath string
|
Filepath string
|
||||||
|
UpCommand string
|
||||||
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
|
||||||
ServerName string // needed for PIA
|
ServerName string // needed for PIA
|
||||||
CanPortForward bool // needed for PIA
|
CanPortForward bool // needed for PIA
|
||||||
@@ -24,6 +25,7 @@ func (s Settings) Copy() (copied Settings) {
|
|||||||
copied.Enabled = gosettings.CopyPointer(s.Enabled)
|
copied.Enabled = gosettings.CopyPointer(s.Enabled)
|
||||||
copied.PortForwarder = s.PortForwarder
|
copied.PortForwarder = s.PortForwarder
|
||||||
copied.Filepath = s.Filepath
|
copied.Filepath = s.Filepath
|
||||||
|
copied.UpCommand = s.UpCommand
|
||||||
copied.Interface = s.Interface
|
copied.Interface = s.Interface
|
||||||
copied.ServerName = s.ServerName
|
copied.ServerName = s.ServerName
|
||||||
copied.CanPortForward = s.CanPortForward
|
copied.CanPortForward = s.CanPortForward
|
||||||
@@ -37,6 +39,7 @@ func (s *Settings) OverrideWith(update Settings) {
|
|||||||
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
|
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
|
||||||
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
|
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
|
||||||
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
||||||
|
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
|
||||||
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
||||||
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
||||||
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
||||||
|
|||||||
@@ -73,6 +73,14 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
|||||||
s.ports = ports
|
s.ports = ports
|
||||||
s.portMutex.Unlock()
|
s.portMutex.Unlock()
|
||||||
|
|
||||||
|
if s.settings.UpCommand != "" {
|
||||||
|
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("running up command: %w", err)
|
||||||
|
s.logger.Error(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
||||||
s.keepPortCancel = keepPortCancel
|
s.keepPortCancel = keepPortCancel
|
||||||
runErrorCh := make(chan error)
|
runErrorCh := make(chan error)
|
||||||
|
|||||||
Reference in New Issue
Block a user