Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67ae5f5065 | ||
|
|
cbfdb25190 | ||
|
|
638f233b3c | ||
|
|
c450c54d67 | ||
|
|
d166314f8b | ||
|
|
7064a44403 | ||
|
|
c33158c13c |
@@ -197,6 +197,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
# Control server
|
||||
HTTP_CONTROL_SERVER_LOG=on \
|
||||
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
||||
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
|
||||
# Server data updater
|
||||
UPDATER_PERIOD=0 \
|
||||
UPDATER_MIN_RATIO=0.8 \
|
||||
|
||||
@@ -161,12 +161,14 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return cli.Update(ctx, args[2:], logger)
|
||||
case "format-servers":
|
||||
return cli.FormatServers(args[2:])
|
||||
case "genkey":
|
||||
return cli.GenKey(args[2:])
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||
}
|
||||
}
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
Created: buildInfo.Created,
|
||||
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
|
||||
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -474,6 +476,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
||||
logger.New(log.SetComponent("http server")),
|
||||
allSettings.ControlServer.AuthFilePath,
|
||||
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
@@ -595,6 +598,7 @@ type clier interface {
|
||||
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
||||
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
||||
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
||||
GenKey(args []string) error
|
||||
}
|
||||
|
||||
type Tun interface {
|
||||
|
||||
1
go.mod
1
go.mod
@@ -8,6 +8,7 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/klauspost/compress v1.17.8
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2
|
||||
github.com/qdm12/dns v1.11.0
|
||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
||||
github.com/qdm12/gosettings v0.4.2
|
||||
|
||||
8
go.sum
8
go.sum
@@ -83,6 +83,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -113,10 +115,16 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
|
||||
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
||||
|
||||
66
internal/cli/genkey.go
Normal file
66
internal/cli/genkey.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *CLI) GenKey(args []string) (err error) {
|
||||
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
|
||||
err = flagSet.Parse(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing flags: %w", err)
|
||||
}
|
||||
|
||||
const keyLength = 128 / 8
|
||||
keyBytes := make([]byte, keyLength)
|
||||
|
||||
_, _ = rand.Read(keyBytes)
|
||||
|
||||
key := base58Encode(keyBytes)
|
||||
fmt.Println(key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func base58Encode(data []byte) string {
|
||||
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||
const radix = 58
|
||||
|
||||
zcount := 0
|
||||
for zcount < len(data) && data[zcount] == 0 {
|
||||
zcount++
|
||||
}
|
||||
|
||||
// integer simplification of ceil(log(256)/log(58))
|
||||
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
|
||||
size := zcount + ceilLog256Div58
|
||||
|
||||
output := make([]byte, size)
|
||||
|
||||
high := size - 1
|
||||
for _, b := range data {
|
||||
i := size - 1
|
||||
for carry := uint32(b); i > high || carry != 0; i-- {
|
||||
carry += 256 * uint32(output[i]) //nolint:gomnd
|
||||
output[i] = byte(carry % radix)
|
||||
carry /= radix
|
||||
}
|
||||
high = i
|
||||
}
|
||||
|
||||
// Determine the additional "zero-gap" in the output buffer
|
||||
additionalZeroGapEnd := zcount
|
||||
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
|
||||
additionalZeroGapEnd++
|
||||
}
|
||||
|
||||
val := output[additionalZeroGapEnd-zcount:]
|
||||
size = len(val)
|
||||
for i := range val {
|
||||
output[i] = alphabet[val[i]]
|
||||
}
|
||||
|
||||
return string(output[:size])
|
||||
}
|
||||
@@ -19,6 +19,11 @@ type ControlServer struct {
|
||||
// Log can be true or false to enable logging on requests.
|
||||
// It cannot be nil in the internal state.
|
||||
Log *bool
|
||||
// AuthFilePath is the path to the file containing the authentication
|
||||
// configuration for the middleware.
|
||||
// It cannot be empty in the internal state and defaults to
|
||||
// /gluetun/auth/config.toml.
|
||||
AuthFilePath string
|
||||
}
|
||||
|
||||
func (c ControlServer) validate() (err error) {
|
||||
@@ -46,6 +51,7 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
||||
return ControlServer{
|
||||
Address: gosettings.CopyPointer(c.Address),
|
||||
Log: gosettings.CopyPointer(c.Log),
|
||||
AuthFilePath: c.AuthFilePath,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,11 +61,13 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
||||
func (c *ControlServer) overrideWith(other ControlServer) {
|
||||
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
||||
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
||||
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
|
||||
}
|
||||
|
||||
func (c *ControlServer) setDefaults() {
|
||||
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
||||
c.Log = gosettings.DefaultPointer(c.Log, true)
|
||||
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
|
||||
}
|
||||
|
||||
func (c ControlServer) String() string {
|
||||
@@ -70,6 +78,7 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Control server settings:")
|
||||
node.Appendf("Listening address: %s", *c.Address)
|
||||
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
||||
node.Appendf("Authentication file path: %s", c.AuthFilePath)
|
||||
return node
|
||||
}
|
||||
|
||||
@@ -78,6 +87,10 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
||||
|
||||
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||
}
|
||||
|
||||
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
|
||||
if vpnServiceProvider == providers.Custom {
|
||||
switch len(settings.Names) {
|
||||
case 0:
|
||||
case 1:
|
||||
// Allow a single name to be specified for the custom provider in case
|
||||
// the user wants to use VPN server side port forwarding with PIA
|
||||
// which requires a server name for TLS verification.
|
||||
filterChoices.Names = settings.Names
|
||||
default:
|
||||
return fmt.Errorf("%w: %d names specified instead of "+
|
||||
"0 or 1 for the custom provider",
|
||||
ErrNameNotValid, len(settings.Names))
|
||||
}
|
||||
}
|
||||
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
||||
if err != nil {
|
||||
@@ -229,6 +237,8 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
|
||||
switch {
|
||||
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
||||
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
||||
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
||||
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
|
||||
case *settings.StreamOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
||||
|
||||
@@ -78,7 +78,8 @@ func Test_Settings_String(t *testing.T) {
|
||||
| └── Enabled: no
|
||||
├── Control server settings:
|
||||
| ├── Listening address: :8000
|
||||
| └── Logging: yes
|
||||
| ├── Logging: yes
|
||||
| └── Authentication file path: /gluetun/auth/config.toml
|
||||
├── OS Alpine settings:
|
||||
| ├── Process UID: 1000
|
||||
| └── Process GID: 1000
|
||||
|
||||
98
internal/firewall/delete.go
Normal file
98
internal/firewall/delete.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isDeleteMatchInstruction returns true if the iptables instruction
|
||||
// is a delete instruction by rule matching. It returns false if the
|
||||
// instruction is a delete instruction by line number, or not a delete
|
||||
// instruction.
|
||||
func isDeleteMatchInstruction(instruction string) bool {
|
||||
fields := strings.Fields(instruction)
|
||||
for i, field := range fields {
|
||||
switch {
|
||||
case field != "-D" && field != "--delete": //nolint:goconst
|
||||
continue
|
||||
case i == len(fields)-1: // malformed: missing chain name
|
||||
return false
|
||||
case i == len(fields)-2: // chain name is last field
|
||||
return true
|
||||
default:
|
||||
// chain name is fields[i+1]
|
||||
const base, bitLength = 10, 16
|
||||
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
|
||||
return err != nil // not a line number
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
|
||||
runner Runner, logger Logger) (err error) {
|
||||
targetRule, err := parseIptablesInstruction(instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, iptablesBinary,
|
||||
targetRule, runner, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding iptables chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
logger.Debug("rule matching \"" + instruction + "\" not found")
|
||||
return nil
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
|
||||
instruction, lineNumber))
|
||||
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
|
||||
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findLineNumber finds the line number of an iptables rule.
|
||||
// It returns 0 if the rule is not found.
|
||||
func findLineNumber(ctx context.Context, iptablesBinary string,
|
||||
instruction iptablesInstruction, runner Runner, logger Logger) (
|
||||
lineNumber uint16, err error) {
|
||||
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
|
||||
"--line-numbers", "-n", "-v"}
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
chain, err := parseChain(output)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing chain list: %w", err)
|
||||
}
|
||||
|
||||
for _, rule := range chain.rules {
|
||||
if instruction.equalToRule(instruction.table, chain.name, rule) {
|
||||
return rule.lineNumber, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
188
internal/firewall/delete_test.go
Normal file
188
internal/firewall/delete_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_isDeleteMatchInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
isDeleteMatch bool
|
||||
}{
|
||||
"not_delete": {
|
||||
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
|
||||
},
|
||||
"malformed_missing_chain_name": {
|
||||
instruction: "-t nat -D",
|
||||
},
|
||||
"delete_chain_name_last_field": {
|
||||
instruction: "-t nat --delete PREROUTING",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_match": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_line_number_last_field": {
|
||||
instruction: "-t nat -D PREROUTING 2",
|
||||
},
|
||||
"delete_line_number": {
|
||||
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
|
||||
},
|
||||
}
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
|
||||
|
||||
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
|
||||
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
|
||||
"^--line-numbers$", "^-n$", "^-v$")
|
||||
}
|
||||
|
||||
func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
"rule_not_found": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
|
||||
nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
"rule_found_delete_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
instruction := testCase.instruction
|
||||
var runner *MockRunner
|
||||
if testCase.makeRunner != nil {
|
||||
runner = testCase.makeRunner(ctrl)
|
||||
}
|
||||
var logger *MockLogger
|
||||
if testCase.makeLogger != nil {
|
||||
logger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
13
internal/firewall/interfaces.go
Normal file
13
internal/firewall/interfaces.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package firewall
|
||||
|
||||
import "github.com/qdm12/golibs/command"
|
||||
|
||||
type Runner interface {
|
||||
Run(cmd command.ExecCmd) (output string, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
c.logger.Debug(c.ip6Tables + " " + instruction)
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ip6Tables, instruction, output, err)
|
||||
@@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
case "ACCEPT", "DROP": //nolint:goconst
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
|
||||
@@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
c.logger.Debug(c.ipTables + " " + instruction)
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ipTables, instruction, output, err)
|
||||
@@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool) error {
|
||||
protocol := connection.Protocol
|
||||
if protocol == "tcp-client" {
|
||||
protocol = "tcp"
|
||||
protocol = "tcp" //nolint:goconst
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||
|
||||
381
internal/firewall/list.go
Normal file
381
internal/firewall/list.go
Normal file
@@ -0,0 +1,381 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type chain struct {
|
||||
name string
|
||||
policy string
|
||||
packets uint64
|
||||
bytes uint64
|
||||
rules []chainRule
|
||||
}
|
||||
|
||||
type chainRule struct {
|
||||
lineNumber uint16 // starts from 1 and cannot be zero.
|
||||
packets uint64
|
||||
bytes uint64
|
||||
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
||||
protocol string // "tcp", "udp" or "" for all protocols.
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
)
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// pkts bytes target prot opt in out source destination
|
||||
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
||||
linesWithComments := strings.Split(iptablesOutput, "\n")
|
||||
|
||||
// Filter out lines starting with a '#' character
|
||||
lines := make([]string, 0, len(linesWithComments))
|
||||
for _, line := range linesWithComments {
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check for the legend line
|
||||
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
if len(lines) == 0 {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
c.rules = make([]chainRule, len(lines))
|
||||
for i, line := range lines {
|
||||
c.rules[i], err = parseChainRuleLine(line)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// parseChainGeneralDataLine parses the first line of iptables chain list output.
|
||||
// For example, it can parse the following line:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// It returns a chain struct with the parsed data.
|
||||
func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
runesToRemove := []rune{'(', ')', ','}
|
||||
for _, r := range runesToRemove {
|
||||
line = strings.ReplaceAll(line, string(r), "")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
indexToExpectedValue := map[int]string{
|
||||
0: "Chain",
|
||||
2: "policy",
|
||||
5: "packets",
|
||||
7: "bytes",
|
||||
}
|
||||
for index, expectedValue := range indexToExpectedValue {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
base.policy = fields[3]
|
||||
err = checkTarget(base.policy)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
|
||||
}
|
||||
|
||||
packets, err := parseMetricSize(fields[4])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
base.packets = packets
|
||||
|
||||
bytes, err := parseMetricSize(fields[6])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
base.bytes = bytes
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
)
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
err = parseChainRuleField(fieldIndex, field, &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) > minFields {
|
||||
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
numIndex = iota
|
||||
packetsIndex
|
||||
bytesIndex
|
||||
targetIndex
|
||||
protocolIndex
|
||||
optIndex
|
||||
inputInterfaceIndex
|
||||
outputInterfaceIndex
|
||||
sourceIndex
|
||||
destinationIndex
|
||||
)
|
||||
|
||||
switch fieldIndex {
|
||||
case numIndex:
|
||||
rule.lineNumber, err = parseLineNumber(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing line number: %w", err)
|
||||
}
|
||||
case packetsIndex:
|
||||
rule.packets, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
case bytesIndex:
|
||||
rule.bytes, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
case targetIndex:
|
||||
err = checkTarget(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking target: %w", err)
|
||||
}
|
||||
rule.target = field
|
||||
case protocolIndex:
|
||||
rule.protocol, err = parseProtocol(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing protocol: %w", err)
|
||||
}
|
||||
case optIndex: // ignored
|
||||
case inputInterfaceIndex:
|
||||
rule.inputInterface = field
|
||||
case outputInterfaceIndex:
|
||||
rule.outputInterface = field
|
||||
case sourceIndex:
|
||||
rule.source, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case destinationIndex:
|
||||
rule.destination, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i++
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
}
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
case "ports":
|
||||
i++
|
||||
ports, err := parsePortsCSV(optionalFields[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
)
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTargetUnknown = errors.New("unknown target")
|
||||
)
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
)
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0":
|
||||
case "6":
|
||||
protocol = "tcp"
|
||||
case "17":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
)
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
}
|
||||
|
||||
//nolint:gomnd
|
||||
multiplerLetterToValue := map[byte]uint64{
|
||||
'K': 1000,
|
||||
'M': 1000000,
|
||||
'G': 1000000000,
|
||||
'T': 1000000000000,
|
||||
}
|
||||
|
||||
lastCharacter := size[len(size)-1]
|
||||
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
||||
if ok { // multiplier present
|
||||
size = size[:len(size)-1]
|
||||
} else {
|
||||
multiplier = 1
|
||||
}
|
||||
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
}
|
||||
121
internal/firewall/list_test.go
Normal file
121
internal/firewall/list_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
},
|
||||
"no_rule": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
},
|
||||
},
|
||||
"some_rules": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
|
||||
`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
rules: []chainRule{
|
||||
{
|
||||
lineNumber: 1,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "udp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 2,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "tcp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 3,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "DROP",
|
||||
protocol: "",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,12 +5,6 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
|
||||
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
|
||||
"no default route matching its family", subnet))
|
||||
|
||||
3
internal/firewall/mocks_generate_test.go
Normal file
3
internal/firewall/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package firewall
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger
|
||||
109
internal/firewall/mocks_test.go
Normal file
109
internal/firewall/mocks_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
166
internal/firewall/parse.go
Normal file
166
internal/firewall/parse.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var (
|
||||
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
)
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match": // ignore match
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
portStrings := strings.Split(value, ",")
|
||||
instruction.toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
138
internal/firewall/parse_test.go
Normal file
138
internal/firewall/parse_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseIptablesInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("5.6.7.8/32"),
|
||||
destinationPort: 10000,
|
||||
target: "ACCEPT",
|
||||
},
|
||||
},
|
||||
"nat_redirection": {
|
||||
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
target: "REDIRECT",
|
||||
toPorts: []uint16{5678},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseIPPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
value string
|
||||
prefix netip.Prefix
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
|
||||
},
|
||||
"invalid": {
|
||||
value: "invalid",
|
||||
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
|
||||
},
|
||||
"valid_ipv4_with_bits": {
|
||||
value: "10.0.0.0/16",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
|
||||
},
|
||||
"valid_ipv4_without_bits": {
|
||||
value: "10.0.0.4",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
|
||||
},
|
||||
"valid_ipv6_with_bits": {
|
||||
value: "2001:db8::/32",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
32),
|
||||
},
|
||||
"valid_ipv6_without_bits": {
|
||||
value: "2001:db8::",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
128),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prefix, err := parseIPPrefix(testCase.value)
|
||||
|
||||
assert.Equal(t, testCase.prefix, prefix)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
return newCmdMatcher(path,
|
||||
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -58,9 +59,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
|
||||
servers = make([]models.Server, 0, len(hostToIPs))
|
||||
for _, serverData := range data.Servers {
|
||||
city, region := parseCity(serverData.City)
|
||||
server := models.Server{
|
||||
Country: serverData.Country,
|
||||
City: serverData.City,
|
||||
City: city,
|
||||
Region: region,
|
||||
ISP: serverData.ISP,
|
||||
}
|
||||
|
||||
@@ -96,3 +99,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func parseCity(city string) (parsedCity, region string) {
|
||||
commaIndex := strings.Index(city, ", ")
|
||||
if commaIndex == -1 {
|
||||
return city, ""
|
||||
}
|
||||
return city[:commaIndex], city[commaIndex+2:]
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Check out the JSON data from https://api.nordvpn.com/v2/servers?limit=10
|
||||
@@ -92,6 +93,9 @@ func (s serversData) idToData() (
|
||||
) {
|
||||
groups = make(map[uint32]groupData, len(s.Groups))
|
||||
for _, group := range s.Groups {
|
||||
if group.Type.Identifier == "regions" { //nolint:goconst
|
||||
group.Title = strings.ReplaceAll(group.Title, ",", "")
|
||||
}
|
||||
groups[group.ID] = group
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
|
||||
|
||||
server := models.Server{
|
||||
Country: location.Country.Name,
|
||||
Region: jsonServer.region(groups),
|
||||
Region: region,
|
||||
City: location.Country.City.Name,
|
||||
Categories: jsonServer.categories(groups),
|
||||
Hostname: jsonServer.Hostname,
|
||||
|
||||
@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
}
|
||||
|
||||
serverName := objects.ServerName
|
||||
|
||||
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||
logger := objects.Logger
|
||||
|
||||
if !objects.CanPortForward {
|
||||
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
|
||||
if !dataFound || expired {
|
||||
client := objects.Client
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
|
||||
p.portForwardPath, objects.Username, objects.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
||||
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
|
||||
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
|
||||
return nil, fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
|
||||
@@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
panic("gateway is not set")
|
||||
}
|
||||
|
||||
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||
|
||||
privateIPClient, err := newHTTPClient(objects.ServerName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating custom HTTP client: %w", err)
|
||||
@@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-keepAliveTimer.C:
|
||||
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
|
||||
err = bindPort(ctx, privateIPClient, apiIP, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
@@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
|
||||
if gateway.Is6() {
|
||||
panic("IPv6 gateway not supported")
|
||||
}
|
||||
|
||||
gatewayBytes := gateway.As4()
|
||||
gatewayBytes[2] = 128
|
||||
gatewayBytes[3] = 1
|
||||
return netip.AddrFrom4(gatewayBytes)
|
||||
}
|
||||
|
||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
||||
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, client, username, password)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("fetching token: %w", err)
|
||||
}
|
||||
|
||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("fetching port forwarding data: %w", err)
|
||||
}
|
||||
@@ -286,7 +299,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
|
||||
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
|
||||
port uint16, signature string, expiration time.Time, err error) {
|
||||
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
|
||||
|
||||
@@ -294,7 +307,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti
|
||||
queryParams.Add("token", token)
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Host: net.JoinHostPort(apiIP.String(), "19999"),
|
||||
Path: "/getSignature",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
@@ -340,7 +353,7 @@ var (
|
||||
ErrBadResponse = errors.New("bad response received")
|
||||
)
|
||||
|
||||
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
|
||||
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
|
||||
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serializing payload: %w", err)
|
||||
@@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
|
||||
queryParams.Add("signature", data.Signature)
|
||||
bindPortURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
|
||||
Path: "/bindPort",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
|
||||
@@ -2,13 +2,17 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/log"
|
||||
)
|
||||
|
||||
func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
authSettings auth.Settings,
|
||||
buildInfo models.BuildInformation,
|
||||
vpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter,
|
||||
@@ -17,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
publicIPLooper PublicIPLoop,
|
||||
storage Storage,
|
||||
ipv6Supported bool,
|
||||
) http.Handler {
|
||||
) (httpHandler http.Handler, err error) {
|
||||
handler := &handler{}
|
||||
|
||||
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
||||
@@ -29,16 +33,25 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
||||
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
||||
|
||||
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
||||
handler.setLogEnabled = handlerWithLog.setEnabled
|
||||
authMiddleware, err := auth.New(authSettings, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating auth middleware: %w", err)
|
||||
}
|
||||
|
||||
return handlerWithLog
|
||||
middlewares := []func(http.Handler) http.Handler{
|
||||
authMiddleware,
|
||||
log.New(logger, logging),
|
||||
}
|
||||
httpHandler = handler
|
||||
for _, middleware := range middlewares {
|
||||
httpHandler = middleware(httpHandler)
|
||||
}
|
||||
return httpHandler, nil
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
v0 http.Handler
|
||||
v1 http.Handler
|
||||
setLogEnabled func(enabled bool)
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package server
|
||||
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
infoer
|
||||
warner
|
||||
Warnf(format string, args ...any)
|
||||
errorer
|
||||
}
|
||||
|
||||
|
||||
36
internal/server/middlewares/auth/apikey.go
Normal file
36
internal/server/middlewares/auth/apikey.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type apiKeyMethod struct {
|
||||
apiKeyDigest [32]byte
|
||||
}
|
||||
|
||||
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
|
||||
return &apiKeyMethod{
|
||||
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
|
||||
otherTokenMethod, ok := other.(*apiKeyMethod)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
|
||||
}
|
||||
|
||||
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
|
||||
xAPIKey := request.Header.Get("X-API-Key")
|
||||
if xAPIKey == "" {
|
||||
xAPIKey = request.URL.Query().Get("api_key")
|
||||
}
|
||||
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
|
||||
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
|
||||
}
|
||||
37
internal/server/middlewares/auth/basic.go
Normal file
37
internal/server/middlewares/auth/basic.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type basicAuthMethod struct {
|
||||
authDigest [32]byte
|
||||
}
|
||||
|
||||
func newBasicAuthMethod(username, password string) *basicAuthMethod {
|
||||
return &basicAuthMethod{
|
||||
authDigest: sha256.Sum256([]byte(username + password)),
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
|
||||
otherBasicMethod, ok := other.(*basicAuthMethod)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return a.authDigest == otherBasicMethod.authDigest
|
||||
}
|
||||
|
||||
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
|
||||
username, password, ok := request.BasicAuth()
|
||||
if !ok {
|
||||
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
return false
|
||||
}
|
||||
requestAuthDigest := sha256.Sum256([]byte(username + password))
|
||||
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
|
||||
}
|
||||
35
internal/server/middlewares/auth/configfile.go
Normal file
35
internal/server/middlewares/auth/configfile.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
// If the file does not exist, it returns empty settings and no error.
|
||||
func Read(filepath string) (settings Settings, err error) {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return Settings{}, nil
|
||||
}
|
||||
return settings, fmt.Errorf("opening file: %w", err)
|
||||
}
|
||||
decoder := toml.NewDecoder(file)
|
||||
decoder.DisallowUnknownFields()
|
||||
err = decoder.Decode(&settings)
|
||||
if err == nil {
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
strictErr := new(toml.StrictMissingError)
|
||||
ok := errors.As(err, &strictErr)
|
||||
if !ok {
|
||||
return settings, fmt.Errorf("toml decoding file: %w", err)
|
||||
}
|
||||
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
|
||||
strictErr, strictErr.String())
|
||||
}
|
||||
80
internal/server/middlewares/auth/configfile_test.go
Normal file
80
internal/server/middlewares/auth/configfile_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
func Test_Read(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
fileContent string
|
||||
settings Settings
|
||||
errMessage string
|
||||
}{
|
||||
"empty_file": {},
|
||||
"malformed_toml": {
|
||||
fileContent: "this is not a toml file",
|
||||
errMessage: `toml decoding file: toml: expected character =`,
|
||||
},
|
||||
"unknown_field": {
|
||||
fileContent: `unknown = "what is this"`,
|
||||
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
|
||||
1| unknown = "what is this"
|
||||
| ~~~~~~~ missing field`,
|
||||
},
|
||||
"filled_settings": {
|
||||
fileContent: `[[roles]]
|
||||
name = "public"
|
||||
auth = "none"
|
||||
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
|
||||
|
||||
[[roles]]
|
||||
name = "client"
|
||||
auth = "apikey"
|
||||
apikey = "xyz"
|
||||
routes = ["GET /v1/vpn/status"]
|
||||
`,
|
||||
settings: Settings{
|
||||
Roles: []Role{{
|
||||
Name: "public",
|
||||
Auth: AuthNone,
|
||||
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
|
||||
}, {
|
||||
Name: "client",
|
||||
Auth: AuthAPIKey,
|
||||
APIKey: "xyz",
|
||||
Routes: []string{"GET /v1/vpn/status"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
filepath := tempDir + "/config.toml"
|
||||
const permissions fs.FileMode = 0600
|
||||
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings, err := Read(filepath)
|
||||
|
||||
assert.Equal(t, testCase.settings, settings)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
22
internal/server/middlewares/auth/format.go
Normal file
22
internal/server/middlewares/auth/format.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package auth
|
||||
|
||||
func andStrings(strings []string) (result string) {
|
||||
return joinStrings(strings, "and")
|
||||
}
|
||||
|
||||
func joinStrings(strings []string, lastJoin string) (result string) {
|
||||
if len(strings) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result = strings[0]
|
||||
for i := 1; i < len(strings); i++ {
|
||||
if i < len(strings)-1 {
|
||||
result += ", " + strings[i]
|
||||
} else {
|
||||
result += " " + lastJoin + " " + strings[i]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
6
internal/server/middlewares/auth/interfaces.go
Normal file
6
internal/server/middlewares/auth/interfaces.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package auth
|
||||
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
}
|
||||
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
type authorizationChecker interface {
|
||||
equal(other authorizationChecker) bool
|
||||
isAuthorized(headers http.Header, request *http.Request) bool
|
||||
}
|
||||
47
internal/server/middlewares/auth/lookup.go
Normal file
47
internal/server/middlewares/auth/lookup.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type internalRole struct {
|
||||
name string
|
||||
checker authorizationChecker
|
||||
}
|
||||
|
||||
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
|
||||
routeToRoles = make(map[string][]internalRole)
|
||||
for _, role := range settings.Roles {
|
||||
var checker authorizationChecker
|
||||
switch role.Auth {
|
||||
case AuthNone:
|
||||
checker = newNoneMethod()
|
||||
case AuthAPIKey:
|
||||
checker = newAPIKeyMethod(role.APIKey)
|
||||
case AuthBasic:
|
||||
checker = newBasicAuthMethod(role.Username, role.Password)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
|
||||
}
|
||||
|
||||
iRole := internalRole{
|
||||
name: role.Name,
|
||||
checker: checker,
|
||||
}
|
||||
for _, route := range role.Routes {
|
||||
checkerExists := false
|
||||
for _, role := range routeToRoles[route] {
|
||||
if role.checker.equal(iRole.checker) {
|
||||
checkerExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if checkerExists {
|
||||
// even if the role name is different, if the checker is the same, skip it.
|
||||
continue
|
||||
}
|
||||
routeToRoles[route] = append(routeToRoles[route], iRole)
|
||||
}
|
||||
}
|
||||
return routeToRoles, nil
|
||||
}
|
||||
60
internal/server/middlewares/auth/lookup_test.go
Normal file
60
internal/server/middlewares/auth/lookup_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
func Test_settingsToLookupMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
routeToRoles map[string][]internalRole
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty_settings": {
|
||||
routeToRoles: map[string][]internalRole{},
|
||||
},
|
||||
"auth_method_not_supported": {
|
||||
settings: Settings{
|
||||
Roles: []Role{{Name: "a", Auth: "bad"}},
|
||||
},
|
||||
errWrapped: ErrMethodNotSupported,
|
||||
errMessage: "authentication method not supported: bad",
|
||||
},
|
||||
"success": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
|
||||
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
|
||||
},
|
||||
},
|
||||
routeToRoles: map[string][]internalRole{
|
||||
"GET /path": {
|
||||
{name: "a", checker: newNoneMethod()}, // deduplicated method
|
||||
},
|
||||
"PUT /path": {
|
||||
{name: "b", checker: newNoneMethod()},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
routeToRoles, err := settingsToLookupMap(testCase.settings)
|
||||
|
||||
assert.Equal(t, testCase.routeToRoles, routeToRoles)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
111
internal/server/middlewares/auth/middleware.go
Normal file
111
internal/server/middlewares/auth/middleware.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func New(settings Settings, debugLogger DebugLogger) (
|
||||
middleware func(http.Handler) http.Handler,
|
||||
err error) {
|
||||
routeToRoles, err := settingsToLookupMap(settings)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
|
||||
}
|
||||
|
||||
//nolint:goconst
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return &authHandler{
|
||||
childHandler: handler,
|
||||
routeToRoles: routeToRoles,
|
||||
unprotectedRoutes: map[string]struct{}{
|
||||
http.MethodGet + " /openvpn/actions/restart": {},
|
||||
http.MethodGet + " /unbound/actions/restart": {},
|
||||
http.MethodGet + " /updater/restart": {},
|
||||
http.MethodGet + " /v1/version": {},
|
||||
http.MethodGet + " /v1/vpn/status": {},
|
||||
http.MethodPut + " /v1/vpn/status": {},
|
||||
// GET /v1/vpn/settings is protected by default
|
||||
// PUT /v1/vpn/settings is protected by default
|
||||
http.MethodGet + " /v1/openvpn/status": {},
|
||||
http.MethodPut + " /v1/openvpn/status": {},
|
||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||
// GET /v1/openvpn/settings is protected by default
|
||||
http.MethodGet + " /v1/dns/status": {},
|
||||
http.MethodPut + " /v1/dns/status": {},
|
||||
http.MethodGet + " /v1/updater/status": {},
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
},
|
||||
logger: debugLogger,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
childHandler http.Handler
|
||||
routeToRoles map[string][]internalRole
|
||||
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
|
||||
logger DebugLogger
|
||||
}
|
||||
|
||||
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
route := request.Method + " " + request.URL.Path
|
||||
roles := h.routeToRoles[route]
|
||||
if len(roles) == 0 {
|
||||
h.logger.Debugf("no authentication role defined for route %s", route)
|
||||
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
responseHeader := make(http.Header, 0)
|
||||
for _, role := range roles {
|
||||
if !role.checker.isAuthorized(responseHeader, request) {
|
||||
continue
|
||||
}
|
||||
|
||||
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
|
||||
|
||||
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
|
||||
h.childHandler.ServeHTTP(writer, request)
|
||||
return
|
||||
}
|
||||
|
||||
// Flush out response headers if all roles failed to authenticate
|
||||
for headerKey, headerValues := range responseHeader {
|
||||
for _, headerValue := range headerValues {
|
||||
writer.Header().Add(headerKey, headerValue)
|
||||
}
|
||||
}
|
||||
|
||||
allRoleNames := make([]string, len(roles))
|
||||
for i, role := range roles {
|
||||
allRoleNames[i] = role.name
|
||||
}
|
||||
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
|
||||
route, andStrings(allRoleNames))
|
||||
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
|
||||
// TODO v3.41.0 remove
|
||||
if role.name != "public" {
|
||||
// custom role name, allow none authentication to be specified
|
||||
return
|
||||
}
|
||||
_, isNoneChecker := role.checker.(*noneMethod)
|
||||
if !isNoneChecker {
|
||||
// not the none authentication method
|
||||
return
|
||||
}
|
||||
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
|
||||
if !isUnprotectedByDefault {
|
||||
// route is not unprotected by default, so this is a user decision
|
||||
return
|
||||
}
|
||||
h.logger.Warnf("route %s is unprotected by default, "+
|
||||
"please set up authentication following the documentation at "+
|
||||
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||
"since this will become no longer publicly accessible after release v3.40.",
|
||||
route)
|
||||
}
|
||||
124
internal/server/middlewares/auth/middleware_test.go
Normal file
124
internal/server/middlewares/auth/middleware_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_authHandler_ServeHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
|
||||
requestMethod string
|
||||
requestPath string
|
||||
statusCode int
|
||||
responseBody string
|
||||
}{
|
||||
"route_has_no_role": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/b",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
responseBody: "Unauthorized\n",
|
||||
},
|
||||
"authorized_unprotected_by_default": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Warnf("route %s is unprotected by default, "+
|
||||
"please set up authentication following the documentation at "+
|
||||
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||
"since this will become no longer publicly accessible after release v3.40.",
|
||||
"GET /v1/vpn/status")
|
||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||
"GET /v1/vpn/status", "public")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/v1/vpn/status",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
"authorized_none": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||
"GET /a", "role1")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/a",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var debugLogger DebugLogger
|
||||
if testCase.makeLogger != nil {
|
||||
debugLogger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
middleware, err := New(testCase.settings, debugLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := middleware(childHandler)
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
client := server.Client()
|
||||
|
||||
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
|
||||
require.NoError(t, err)
|
||||
request, err := http.NewRequestWithContext(context.Background(),
|
||||
testCase.requestMethod, requestURL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := client.Do(request)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err = response.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
assert.Equal(t, testCase.statusCode, response.StatusCode)
|
||||
body, err := io.ReadAll(response.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testCase.responseBody, string(body))
|
||||
})
|
||||
}
|
||||
}
|
||||
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package auth
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger
|
||||
68
internal/server/middlewares/auth/mocks_test.go
Normal file
68
internal/server/middlewares/auth/mocks_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
|
||||
|
||||
// Package auth is a generated GoMock package.
|
||||
package auth
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDebugLogger is a mock of DebugLogger interface.
|
||||
type MockDebugLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDebugLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
|
||||
type MockDebugLoggerMockRecorder struct {
|
||||
mock *MockDebugLogger
|
||||
}
|
||||
|
||||
// NewMockDebugLogger creates a new mock instance.
|
||||
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
mock := &MockDebugLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockDebugLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debugf mocks base method.
|
||||
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Debugf", varargs...)
|
||||
}
|
||||
|
||||
// Debugf indicates an expected call of Debugf.
|
||||
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
|
||||
}
|
||||
|
||||
// Warnf mocks base method.
|
||||
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Warnf", varargs...)
|
||||
}
|
||||
|
||||
// Warnf indicates an expected call of Warnf.
|
||||
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
|
||||
}
|
||||
20
internal/server/middlewares/auth/none.go
Normal file
20
internal/server/middlewares/auth/none.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
type noneMethod struct{}
|
||||
|
||||
func newNoneMethod() *noneMethod {
|
||||
return &noneMethod{}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (n *noneMethod) equal(other authorizationChecker) bool {
|
||||
_, ok := other.(*noneMethod)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
131
internal/server/middlewares/auth/settings.go
Normal file
131
internal/server/middlewares/auth/settings.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/validate"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
// Roles is a list of roles with their associated authentication
|
||||
// and routes.
|
||||
Roles []Role
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
|
||||
Name: "public",
|
||||
Auth: "none",
|
||||
Routes: []string{
|
||||
http.MethodGet + " /openvpn/actions/restart",
|
||||
http.MethodGet + " /unbound/actions/restart",
|
||||
http.MethodGet + " /updater/restart",
|
||||
http.MethodGet + " /v1/version",
|
||||
http.MethodGet + " /v1/vpn/status",
|
||||
http.MethodPut + " /v1/vpn/status",
|
||||
http.MethodGet + " /v1/openvpn/status",
|
||||
http.MethodPut + " /v1/openvpn/status",
|
||||
http.MethodGet + " /v1/openvpn/portforwarded",
|
||||
http.MethodGet + " /v1/dns/status",
|
||||
http.MethodPut + " /v1/dns/status",
|
||||
http.MethodGet + " /v1/updater/status",
|
||||
http.MethodPut + " /v1/updater/status",
|
||||
http.MethodGet + " /v1/publicip/ip",
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
func (s Settings) Validate() (err error) {
|
||||
for i, role := range s.Roles {
|
||||
err = role.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("role %s (%d of %d): %w",
|
||||
role.Name, i+1, len(s.Roles), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
AuthNone = "none"
|
||||
AuthAPIKey = "apikey"
|
||||
AuthBasic = "basic"
|
||||
)
|
||||
|
||||
// Role contains the role name, authentication method name and
|
||||
// routes that the role can access.
|
||||
type Role struct {
|
||||
// Name is the role name and is only used for documentation
|
||||
// and in the authentication middleware debug logs.
|
||||
Name string
|
||||
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
|
||||
Auth string
|
||||
// APIKey is the API key to use when using the 'apikey' authentication.
|
||||
APIKey string
|
||||
// Username for HTTP Basic authentication method.
|
||||
Username string
|
||||
// Password for HTTP Basic authentication method.
|
||||
Password string
|
||||
// Routes is a list of routes that the role can access in the format
|
||||
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
|
||||
Routes []string
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMethodNotSupported = errors.New("authentication method not supported")
|
||||
ErrAPIKeyEmpty = errors.New("api key is empty")
|
||||
ErrBasicUsernameEmpty = errors.New("username is empty")
|
||||
ErrBasicPasswordEmpty = errors.New("password is empty")
|
||||
ErrRouteNotSupported = errors.New("route not supported by the control server")
|
||||
)
|
||||
|
||||
func (r Role) validate() (err error) {
|
||||
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Auth == AuthAPIKey && r.APIKey == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
|
||||
case r.Auth == AuthBasic && r.Username == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
|
||||
case r.Auth == AuthBasic && r.Password == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
|
||||
}
|
||||
|
||||
for i, route := range r.Routes {
|
||||
_, ok := validRoutes[route]
|
||||
if !ok {
|
||||
return fmt.Errorf("route %d of %d: %w: %s",
|
||||
i+1, len(r.Routes), ErrRouteNotSupported, route)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WARNING: do not mutate programmatically.
|
||||
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
||||
http.MethodGet + " /openvpn/actions/restart": {},
|
||||
http.MethodGet + " /unbound/actions/restart": {},
|
||||
http.MethodGet + " /updater/restart": {},
|
||||
http.MethodGet + " /v1/version": {},
|
||||
http.MethodGet + " /v1/vpn/status": {},
|
||||
http.MethodPut + " /v1/vpn/status": {},
|
||||
http.MethodGet + " /v1/vpn/settings": {},
|
||||
http.MethodPut + " /v1/vpn/settings": {},
|
||||
http.MethodGet + " /v1/openvpn/status": {},
|
||||
http.MethodPut + " /v1/openvpn/status": {},
|
||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||
http.MethodGet + " /v1/openvpn/settings": {},
|
||||
http.MethodGet + " /v1/dns/status": {},
|
||||
http.MethodPut + " /v1/dns/status": {},
|
||||
http.MethodGet + " /v1/updater/status": {},
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
}
|
||||
5
internal/server/middlewares/log/interfaces.go
Normal file
5
internal/server/middlewares/log/interfaces.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package log
|
||||
|
||||
type Logger interface {
|
||||
Info(message string)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package log
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
|
||||
func New(logger Logger, enabled bool) (
|
||||
middleware func(http.Handler) http.Handler) {
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return &logMiddleware{
|
||||
childHandler: childHandler,
|
||||
childHandler: handler,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
enabled: enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type logMiddleware struct {
|
||||
childHandler http.Handler
|
||||
logger infoer
|
||||
logger Logger
|
||||
timeNow func() time.Time
|
||||
enabled bool
|
||||
enabledMu sync.RWMutex
|
||||
@@ -39,7 +42,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RemoteAddr + " in " + duration.String())
|
||||
}
|
||||
|
||||
func (m *logMiddleware) setEnabled(enabled bool) {
|
||||
func (m *logMiddleware) SetEnabled(enabled bool) {
|
||||
m.enabledMu.Lock()
|
||||
defer m.enabledMu.Unlock()
|
||||
m.enabled = enabled
|
||||
@@ -6,17 +6,31 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/httpserver"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||
)
|
||||
|
||||
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
||||
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
||||
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
||||
ipv6Supported bool) (
|
||||
server *httpserver.Server, err error) {
|
||||
handler := newHandler(ctx, logger, logEnabled, buildInfo,
|
||||
authSettings, err := auth.Read(authConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading auth settings: %w", err)
|
||||
}
|
||||
authSettings.SetDefaults()
|
||||
err = authSettings.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating auth settings: %w", err)
|
||||
}
|
||||
|
||||
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
|
||||
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating handler: %w", err)
|
||||
}
|
||||
|
||||
httpServerSettings := httpserver.Settings{
|
||||
Address: address,
|
||||
|
||||
@@ -128,6 +128,31 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
|
||||
messageParts = append(messageParts, "premium tier only")
|
||||
}
|
||||
|
||||
if *selection.StreamOnly {
|
||||
messageParts = append(messageParts, "stream only")
|
||||
}
|
||||
|
||||
if *selection.MultiHopOnly {
|
||||
messageParts = append(messageParts, "multihop only")
|
||||
}
|
||||
|
||||
if *selection.PortForwardOnly {
|
||||
messageParts = append(messageParts, "port forwarding only")
|
||||
}
|
||||
|
||||
if *selection.SecureCoreOnly {
|
||||
messageParts = append(messageParts, "secure core only")
|
||||
}
|
||||
|
||||
if *selection.TorOnly {
|
||||
messageParts = append(messageParts, "tor only")
|
||||
}
|
||||
|
||||
if selection.TargetIP.IsValid() {
|
||||
messageParts = append(messageParts,
|
||||
"target ip address "+selection.TargetIP.String())
|
||||
}
|
||||
|
||||
message := "for " + strings.Join(messageParts, "; ")
|
||||
|
||||
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user