Compare commits

...

7 Commits

Author SHA1 Message Date
Quentin McGaw
67ae5f5065 feat(server): role based authentication system (#2434)
- Parse toml configuration file, see https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/control-server.md#authentication
- Retro-compatible with existing AND documented routes, until after v3.41 release
- Log a warning if an unprotected-by-default route is accessed unprotected
- Authentication methods: none, apikey, basic
- `genkey` command to generate API keys
- move log middleware to internal/server/middlewares/log

Co-authored-by: Joe Jose <45399349+joejose97@users.noreply.github.com>
2024-09-29 17:53:17 +00:00
Quentin McGaw
cbfdb25190 fix(settings): prevent using FREE_ONLY and PORT_FORWARD_ONLY together with protonvpn (see #2470) 2024-09-29 17:53:17 +00:00
Quentin McGaw
638f233b3c fix(storage): add missing selection fields to build noServerFoundError
- `STREAM_ONLY`, `PORT_FORWARD_ONLY`, `SECURE_CORE_ONLY`, `TOR_ONLY` and target ip options affected
- Refers to issue #2470
2024-09-29 17:53:17 +00:00
Quentin McGaw
c450c54d67 fix(ivpn): split city into city and region
- Fix bad city values containing a comma
- update ivpn servers data
2024-09-29 17:53:17 +00:00
Quentin McGaw
d166314f8b fix(nordvpn): remove commas from region values 2024-09-29 17:53:17 +00:00
Quentin McGaw
7064a44403 fix(pia): support port forwarding using Wireguard (#2420)
- Build API IP address using the first 2 bytes of the gateway IP and adding `128.1` to it
- API IP address is valid for both OpenVPN and Wireguard
- Fix #2320
2024-09-29 17:53:17 +00:00
Quentin McGaw
c33158c13c fix(firewall): delete chain rules by line number (#2411)
- Fix #2334
- Parsing of iptables chains, contributing to progress for #1856
2024-09-29 17:53:04 +00:00
48 changed files with 2946 additions and 725 deletions

View File

@@ -197,6 +197,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
# Control server
HTTP_CONTROL_SERVER_LOG=on \
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
# Server data updater
UPDATER_PERIOD=0 \
UPDATER_MIN_RATIO=0.8 \

View File

@@ -161,12 +161,14 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return cli.Update(ctx, args[2:], logger)
case "format-servers":
return cli.FormatServers(args[2:])
case "genkey":
return cli.GenKey(args[2:])
default:
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
}
}
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
if err != nil {
return err
}
@@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Version: buildInfo.Version,
Commit: buildInfo.Commit,
Created: buildInfo.Created,
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
Announcement: "All control server routes will become private by default after the v3.41.0 release",
AnnounceExp: announcementExp,
// Sponsor information
PaypalUser: "qmcgaw",
@@ -474,6 +476,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.New(log.SetComponent("http server")),
allSettings.ControlServer.AuthFilePath,
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
if err != nil {
@@ -595,6 +598,7 @@ type clier interface {
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
GenKey(args []string) error
}
type Tun interface {

1
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/golang/mock v1.6.0
github.com/klauspost/compress v1.17.8
github.com/klauspost/pgzip v1.2.6
github.com/pelletier/go-toml/v2 v2.2.2
github.com/qdm12/dns v1.11.0
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
github.com/qdm12/gosettings v0.4.2

8
go.sum
View File

@@ -83,6 +83,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -113,10 +115,16 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=

66
internal/cli/genkey.go Normal file
View File

@@ -0,0 +1,66 @@
package cli
import (
"crypto/rand"
"flag"
"fmt"
)
func (c *CLI) GenKey(args []string) (err error) {
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
err = flagSet.Parse(args)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
const keyLength = 128 / 8
keyBytes := make([]byte, keyLength)
_, _ = rand.Read(keyBytes)
key := base58Encode(keyBytes)
fmt.Println(key)
return nil
}
func base58Encode(data []byte) string {
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
const radix = 58
zcount := 0
for zcount < len(data) && data[zcount] == 0 {
zcount++
}
// integer simplification of ceil(log(256)/log(58))
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
size := zcount + ceilLog256Div58
output := make([]byte, size)
high := size - 1
for _, b := range data {
i := size - 1
for carry := uint32(b); i > high || carry != 0; i-- {
carry += 256 * uint32(output[i]) //nolint:gomnd
output[i] = byte(carry % radix)
carry /= radix
}
high = i
}
// Determine the additional "zero-gap" in the output buffer
additionalZeroGapEnd := zcount
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
additionalZeroGapEnd++
}
val := output[additionalZeroGapEnd-zcount:]
size = len(val)
for i := range val {
output[i] = alphabet[val[i]]
}
return string(output[:size])
}

View File

@@ -19,6 +19,11 @@ type ControlServer struct {
// Log can be true or false to enable logging on requests.
// It cannot be nil in the internal state.
Log *bool
// AuthFilePath is the path to the file containing the authentication
// configuration for the middleware.
// It cannot be empty in the internal state and defaults to
// /gluetun/auth/config.toml.
AuthFilePath string
}
func (c ControlServer) validate() (err error) {
@@ -44,8 +49,9 @@ func (c ControlServer) validate() (err error) {
func (c *ControlServer) copy() (copied ControlServer) {
return ControlServer{
Address: gosettings.CopyPointer(c.Address),
Log: gosettings.CopyPointer(c.Log),
Address: gosettings.CopyPointer(c.Address),
Log: gosettings.CopyPointer(c.Log),
AuthFilePath: c.AuthFilePath,
}
}
@@ -55,11 +61,13 @@ func (c *ControlServer) copy() (copied ControlServer) {
func (c *ControlServer) overrideWith(other ControlServer) {
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
}
func (c *ControlServer) setDefaults() {
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
c.Log = gosettings.DefaultPointer(c.Log, true)
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
}
func (c ControlServer) String() string {
@@ -70,6 +78,7 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
node = gotree.New("Control server settings:")
node.Appendf("Listening address: %s", *c.Address)
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
node.Appendf("Authentication file path: %s", c.AuthFilePath)
return node
}
@@ -78,6 +87,10 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
if err != nil {
return err
}
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
return nil
}

View File

@@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
}
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
// Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
if vpnServiceProvider == providers.Custom {
switch len(settings.Names) {
case 0:
case 1:
// Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
default:
return fmt.Errorf("%w: %d names specified instead of "+
"0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names))
}
}
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
if err != nil {
@@ -229,6 +237,8 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
switch {
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
case *settings.StreamOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)

View File

@@ -78,7 +78,8 @@ func Test_Settings_String(t *testing.T) {
| └── Enabled: no
├── Control server settings:
| ├── Listening address: :8000
| ── Logging: yes
| ── Logging: yes
| └── Authentication file path: /gluetun/auth/config.toml
├── OS Alpine settings:
| ├── Process UID: 1000
| └── Process GID: 1000

View File

@@ -0,0 +1,98 @@
package firewall
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
)
// isDeleteMatchInstruction returns true if the iptables instruction
// is a delete instruction by rule matching. It returns false if the
// instruction is a delete instruction by line number, or not a delete
// instruction.
func isDeleteMatchInstruction(instruction string) bool {
fields := strings.Fields(instruction)
for i, field := range fields {
switch {
case field != "-D" && field != "--delete": //nolint:goconst
continue
case i == len(fields)-1: // malformed: missing chain name
return false
case i == len(fields)-2: // chain name is last field
return true
default:
// chain name is fields[i+1]
const base, bitLength = 10, 16
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
return err != nil // not a line number
}
}
return false
}
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
runner Runner, logger Logger) (err error) {
targetRule, err := parseIptablesInstruction(instruction)
if err != nil {
return fmt.Errorf("parsing iptables command: %w", err)
}
lineNumber, err := findLineNumber(ctx, iptablesBinary,
targetRule, runner, logger)
if err != nil {
return fmt.Errorf("finding iptables chain rule line number: %w", err)
} else if lineNumber == 0 {
logger.Debug("rule matching \"" + instruction + "\" not found")
return nil
}
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
instruction, lineNumber))
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return err
}
return nil
}
// findLineNumber finds the line number of an iptables rule.
// It returns 0 if the rule is not found.
func findLineNumber(ctx context.Context, iptablesBinary string,
instruction iptablesInstruction, runner Runner, logger Logger) (
lineNumber uint16, err error) {
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
"--line-numbers", "-n", "-v"}
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return 0, err
}
chain, err := parseChain(output)
if err != nil {
return 0, fmt.Errorf("parsing chain list: %w", err)
}
for _, rule := range chain.rules {
if instruction.equalToRule(instruction.table, chain.name, rule) {
return rule.lineNumber, nil
}
}
return 0, nil
}

View File

@@ -0,0 +1,188 @@
package firewall
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_isDeleteMatchInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
instruction string
isDeleteMatch bool
}{
"not_delete": {
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
},
"malformed_missing_chain_name": {
instruction: "-t nat -D",
},
"delete_chain_name_last_field": {
instruction: "-t nat --delete PREROUTING",
isDeleteMatch: true,
},
"delete_match": {
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
isDeleteMatch: true,
},
"delete_line_number_last_field": {
instruction: "-t nat -D PREROUTING 2",
},
"delete_line_number": {
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
})
}
}
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
"^--line-numbers$", "^-n$", "^-v$")
}
func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()
const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
"rule_not_found": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
num pkts bytes target prot opt in out source destination
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
return logger
},
},
"rule_found_delete_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
ctx := context.Background()
instruction := testCase.instruction
var runner *MockRunner
if testCase.makeRunner != nil {
runner = testCase.makeRunner(ctrl)
}
var logger *MockLogger
if testCase.makeLogger != nil {
logger = testCase.makeLogger(ctrl)
}
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -0,0 +1,13 @@
package firewall
import "github.com/qdm12/golibs/command"
type Runner interface {
Run(cmd command.ExecCmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}

View File

@@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
c.logger.Debug(c.ip6Tables + " " + instruction)
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
@@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
case "ACCEPT", "DROP": //nolint:goconst
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}

View File

@@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
c.logger.Debug(c.ipTables + " " + instruction)
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
@@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp"
protocol = "tcp" //nolint:goconst
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,

381
internal/firewall/list.go Normal file
View File

@@ -0,0 +1,381 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type chain struct {
name string
policy string
packets uint64
bytes uint64
rules []chainRule
}
type chainRule struct {
lineNumber uint16 // starts from 1 and cannot be zero.
packets uint64
bytes uint64
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
protocol string // "tcp", "udp" or "" for all protocols.
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
}
var (
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
)
func parseChain(iptablesOutput string) (c chain, err error) {
// Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// pkts bytes target prot opt in out source destination
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
iptablesOutput = strings.TrimSpace(iptablesOutput)
linesWithComments := strings.Split(iptablesOutput, "\n")
// Filter out lines starting with a '#' character
lines := make([]string, 0, len(linesWithComments))
for _, line := range linesWithComments {
if strings.HasPrefix(line, "#") {
continue
}
lines = append(lines, line)
}
const minLines = 2 // chain general information line + legend line
if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput)
}
c, err = parseChainGeneralDataLine(lines[0])
if err != nil {
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
}
// Sanity check for the legend line
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
}
lines = lines[2:] // remove chain general information line and legend line
if len(lines) == 0 {
return c, nil
}
c.rules = make([]chainRule, len(lines))
for i, line := range lines {
c.rules[i], err = parseChainRuleLine(line)
if err != nil {
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
}
}
return c, nil
}
// parseChainGeneralDataLine parses the first line of iptables chain list output.
// For example, it can parse the following line:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// It returns a chain struct with the parsed data.
func parseChainGeneralDataLine(line string) (base chain, err error) {
line = strings.TrimSpace(line)
runesToRemove := []rune{'(', ')', ','}
for _, r := range runesToRemove {
line = strings.ReplaceAll(line, string(r), "")
}
fields := strings.Fields(line)
const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line)
}
// Sanity checks
indexToExpectedValue := map[int]string{
0: "Chain",
2: "policy",
5: "packets",
7: "bytes",
}
for index, expectedValue := range indexToExpectedValue {
if fields[index] == expectedValue {
continue
}
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line)
}
base.name = fields[1] // chain name could be custom
base.policy = fields[3]
err = checkTarget(base.policy)
if err != nil {
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
}
packets, err := parseMetricSize(fields[4])
if err != nil {
return chain{}, fmt.Errorf("parsing packets: %w", err)
}
base.packets = packets
bytes, err := parseMetricSize(fields[6])
if err != nil {
return chain{}, fmt.Errorf("parsing bytes: %w", err)
}
base.bytes = bytes
return base, nil
}
var (
ErrChainRuleMalformed = errors.New("chain rule is malformed")
)
func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line)
if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
}
fields := strings.Fields(line)
const minFields = 10
if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
}
for fieldIndex, field := range fields[:minFields] {
err = parseChainRuleField(fieldIndex, field, &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
}
}
if len(fields) > minFields {
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
}
}
return rule, nil
}
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
}
const (
numIndex = iota
packetsIndex
bytesIndex
targetIndex
protocolIndex
optIndex
inputInterfaceIndex
outputInterfaceIndex
sourceIndex
destinationIndex
)
switch fieldIndex {
case numIndex:
rule.lineNumber, err = parseLineNumber(field)
if err != nil {
return fmt.Errorf("parsing line number: %w", err)
}
case packetsIndex:
rule.packets, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing packets: %w", err)
}
case bytesIndex:
rule.bytes, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing bytes: %w", err)
}
case targetIndex:
err = checkTarget(field)
if err != nil {
return fmt.Errorf("checking target: %w", err)
}
rule.target = field
case protocolIndex:
rule.protocol, err = parseProtocol(field)
if err != nil {
return fmt.Errorf("parsing protocol: %w", err)
}
case optIndex: // ignored
case inputInterfaceIndex:
rule.inputInterface = field
case outputInterfaceIndex:
rule.outputInterface = field
case sourceIndex:
rule.source, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case destinationIndex:
rule.destination, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
}
return nil
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i++
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port %q: %w", value, err)
}
rule.destinationPort = uint16(destinationPort)
case "redir":
i++
switch optionalFields[i] {
case "ports":
i++
ports, err := parsePortsCSV(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
default:
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
}
}
return nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
}
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(field, base, bitLength)
if err != nil {
return nil, fmt.Errorf("parsing port %q: %w", field, err)
}
ports[i] = uint16(port)
}
return ports, nil
}
var (
ErrLineNumberIsZero = errors.New("line number is zero")
)
func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil {
return 0, err
} else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
}
return uint16(lineNumber), nil
}
var (
ErrTargetUnknown = errors.New("unknown target")
)
func checkTarget(target string) (err error) {
switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil
}
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
}
var (
ErrProtocolUnknown = errors.New("unknown protocol")
)
func parseProtocol(s string) (protocol string, err error) {
switch s {
case "0":
case "6":
protocol = "tcp"
case "17":
protocol = "udp"
default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
}
return protocol, nil
}
var (
ErrMetricSizeMalformed = errors.New("metric size is malformed")
)
// parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) {
if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
}
//nolint:gomnd
multiplerLetterToValue := map[byte]uint64{
'K': 1000,
'M': 1000000,
'G': 1000000000,
'T': 1000000000000,
}
lastCharacter := size[len(size)-1]
multiplier, ok := multiplerLetterToValue[lastCharacter]
if ok { // multiplier present
size = size[:len(size)-1]
} else {
multiplier = 1
}
const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength)
if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
}
n *= multiplier
return n, nil
}

View File

@@ -0,0 +1,121 @@
package firewall
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseChain(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
iptablesOutput string
table chain
errWrapped error
errMessage string
}{
"no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
},
"single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
},
"malformed_general_data_line": {
iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"",
},
"malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
},
"no_rule": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
},
},
"some_rules": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
rules: []chainRule{
{
lineNumber: 1,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "udp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 2,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "tcp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 3,
packets: 0,
bytes: 0,
target: "DROP",
protocol: "",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -5,12 +5,6 @@ import (
"net/netip"
)
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
"no default route matching its family", subnet))

View File

@@ -0,0 +1,3 @@
package firewall
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger

View File

@@ -0,0 +1,109 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
// Package firewall is a generated GoMock package.
package firewall
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
command "github.com/qdm12/golibs/command"
)
// MockRunner is a mock of Runner interface.
type MockRunner struct {
ctrl *gomock.Controller
recorder *MockRunnerMockRecorder
}
// MockRunnerMockRecorder is the mock recorder for MockRunner.
type MockRunnerMockRecorder struct {
mock *MockRunner
}
// NewMockRunner creates a new mock instance.
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
mock := &MockRunner{ctrl: ctrl}
mock.recorder = &MockRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
}
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}

166
internal/firewall/parse.go Normal file
View File

@@ -0,0 +1,166 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
}
func (i *iptablesInstruction) setDefaults() {
if i.table == "" {
i.table = "filter"
}
}
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
switch {
case i.table != table:
return false
case i.chain != chain:
return false
case i.target != rule.target:
return false
case i.protocol != rule.protocol:
return false
case i.destinationPort != rule.destinationPort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
return false
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
return false
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
return false
case !ipPrefixesEqual(i.source, rule.source):
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
default:
return true
}
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
}
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
return instruction == chainRule ||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var (
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
)
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.chain = value
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match": // ignore match
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
instruction.outputInterface = value
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
portStrings := strings.Split(value, ",")
instruction.toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
}
return nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
return netip.ParsePrefix(value)
}
ip, err := netip.ParseAddr(value)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}

View File

@@ -0,0 +1,138 @@
package firewall
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseIptablesInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
s string
instruction iptablesInstruction
errWrapped error
errMessage string
}{
"no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction",
},
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("5.6.7.8/32"),
destinationPort: 10000,
target: "ACCEPT",
},
},
"nat_redirection": {
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
target: "REDIRECT",
toPorts: []uint16{5678},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
func Test_parseIPPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
value string
prefix netip.Prefix
errMessage string
}{
"empty": {
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
},
"invalid": {
value: "invalid",
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
},
"valid_ipv4_with_bits": {
value: "10.0.0.0/16",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
},
"valid_ipv4_without_bits": {
value: "10.0.0.4",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
},
"valid_ipv6_with_bits": {
value: "2001:db8::/32",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
32),
},
"valid_ipv6_without_bits": {
value: "2001:db8::",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
128),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix, err := parseIPPrefix(testCase.value)
assert.Equal(t, testCase.prefix, prefix)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,50 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
// Package firewall is a generated GoMock package.
package firewall
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
command "github.com/qdm12/golibs/command"
)
// MockRunner is a mock of Runner interface.
type MockRunner struct {
ctrl *gomock.Controller
recorder *MockRunnerMockRecorder
}
// MockRunnerMockRecorder is the mock recorder for MockRunner.
type MockRunnerMockRecorder struct {
mock *MockRunner
}
// NewMockRunner creates a new mock instance.
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
mock := &MockRunner{ctrl: ctrl}
mock.recorder = &MockRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
}

View File

@@ -11,8 +11,6 @@ import (
"github.com/stretchr/testify/require"
)
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
func newAppendTestRuleMatcher(path string) *cmdMatcher {
return newCmdMatcher(path,
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sort"
"strings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
@@ -58,9 +59,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
servers = make([]models.Server, 0, len(hostToIPs))
for _, serverData := range data.Servers {
city, region := parseCity(serverData.City)
server := models.Server{
Country: serverData.Country,
City: serverData.City,
City: city,
Region: region,
ISP: serverData.ISP,
}
@@ -96,3 +99,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return servers, nil
}
func parseCity(city string) (parsedCity, region string) {
commaIndex := strings.Index(city, ", ")
if commaIndex == -1 {
return city, ""
}
return city[:commaIndex], city[commaIndex+2:]
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
)
// Check out the JSON data from https://api.nordvpn.com/v2/servers?limit=10
@@ -92,6 +93,9 @@ func (s serversData) idToData() (
) {
groups = make(map[uint32]groupData, len(s.Groups))
for _, group := range s.Groups {
if group.Type.Identifier == "regions" { //nolint:goconst
group.Title = strings.ReplaceAll(group.Title, ",", "")
}
groups[group.ID] = group
}

View File

@@ -79,7 +79,7 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
server := models.Server{
Country: location.Country.Name,
Region: jsonServer.region(groups),
Region: region,
City: location.Country.City.Name,
Categories: jsonServer.categories(groups),
Hostname: jsonServer.Hostname,

View File

@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
}
serverName := objects.ServerName
apiIP := buildAPIIPAddress(objects.Gateway)
logger := objects.Logger
if !objects.CanPortForward {
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
if !dataFound || expired {
client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
p.portForwardPath, objects.Username, objects.Password)
if err != nil {
return nil, fmt.Errorf("refreshing port forward data: %w", err)
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
// First time binding
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
return nil, fmt.Errorf("binding port: %w", err)
}
@@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context,
panic("gateway is not set")
}
apiIP := buildAPIIPAddress(objects.Gateway)
privateIPClient, err := newHTTPClient(objects.ServerName)
if err != nil {
return fmt.Errorf("creating custom HTTP client: %w", err)
@@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
return ctx.Err()
case <-keepAliveTimer.C:
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
err = bindPort(ctx, privateIPClient, apiIP, data)
if err != nil {
return fmt.Errorf("binding port: %w", err)
}
@@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
}
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
if gateway.Is6() {
panic("IPv6 gateway not supported")
}
gatewayBytes := gateway.As4()
gatewayBytes[2] = 128
gatewayBytes[3] = 1
return netip.AddrFrom4(gatewayBytes)
}
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, username, password)
if err != nil {
return data, fmt.Errorf("fetching token: %w", err)
}
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
if err != nil {
return data, fmt.Errorf("fetching port forwarding data: %w", err)
}
@@ -286,7 +299,7 @@ func fetchToken(ctx context.Context, client *http.Client,
return result.Token, nil
}
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
port uint16, signature string, expiration time.Time, err error) {
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
@@ -294,7 +307,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti
queryParams.Add("token", token)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Host: net.JoinHostPort(apiIP.String(), "19999"),
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
@@ -340,7 +353,7 @@ var (
ErrBadResponse = errors.New("bad response received")
)
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
payload, err := packPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return fmt.Errorf("serializing payload: %w", err)
@@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
queryParams.Add("signature", data.Signature)
bindPortURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
Path: "/bindPort",
RawQuery: queryParams.Encode(),
}

View File

@@ -2,13 +2,17 @@ package server
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
"github.com/qdm12/gluetun/internal/server/middlewares/log"
)
func newHandler(ctx context.Context, logger infoWarner, logging bool,
func newHandler(ctx context.Context, logger Logger, logging bool,
authSettings auth.Settings,
buildInfo models.BuildInformation,
vpnLooper VPNLooper,
pfGetter PortForwardedGetter,
@@ -17,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
publicIPLooper PublicIPLoop,
storage Storage,
ipv6Supported bool,
) http.Handler {
) (httpHandler http.Handler, err error) {
handler := &handler{}
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
@@ -29,16 +33,25 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
handlerWithLog := withLogMiddleware(handler, logger, logging)
handler.setLogEnabled = handlerWithLog.setEnabled
authMiddleware, err := auth.New(authSettings, logger)
if err != nil {
return nil, fmt.Errorf("creating auth middleware: %w", err)
}
return handlerWithLog
middlewares := []func(http.Handler) http.Handler{
authMiddleware,
log.New(logger, logging),
}
httpHandler = handler
for _, middleware := range middlewares {
httpHandler = middleware(httpHandler)
}
return httpHandler, nil
}
type handler struct {
v0 http.Handler
v1 http.Handler
setLogEnabled func(enabled bool)
v0 http.Handler
v1 http.Handler
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

View File

@@ -1,8 +1,10 @@
package server
type Logger interface {
Debugf(format string, args ...any)
infoer
warner
Warnf(format string, args ...any)
errorer
}

View File

@@ -0,0 +1,36 @@
package auth
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
)
type apiKeyMethod struct {
apiKeyDigest [32]byte
}
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
return &apiKeyMethod{
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
otherTokenMethod, ok := other.(*apiKeyMethod)
if !ok {
return false
}
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
}
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
xAPIKey := request.Header.Get("X-API-Key")
if xAPIKey == "" {
xAPIKey = request.URL.Query().Get("api_key")
}
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
}

View File

@@ -0,0 +1,37 @@
package auth
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
)
type basicAuthMethod struct {
authDigest [32]byte
}
func newBasicAuthMethod(username, password string) *basicAuthMethod {
return &basicAuthMethod{
authDigest: sha256.Sum256([]byte(username + password)),
}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
otherBasicMethod, ok := other.(*basicAuthMethod)
if !ok {
return false
}
return a.authDigest == otherBasicMethod.authDigest
}
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
username, password, ok := request.BasicAuth()
if !ok {
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
return false
}
requestAuthDigest := sha256.Sum256([]byte(username + password))
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
}

View File

@@ -0,0 +1,35 @@
package auth
import (
"errors"
"fmt"
"os"
"github.com/pelletier/go-toml/v2"
)
// Read reads the toml file specified by the filepath given.
// If the file does not exist, it returns empty settings and no error.
func Read(filepath string) (settings Settings, err error) {
file, err := os.Open(filepath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return Settings{}, nil
}
return settings, fmt.Errorf("opening file: %w", err)
}
decoder := toml.NewDecoder(file)
decoder.DisallowUnknownFields()
err = decoder.Decode(&settings)
if err == nil {
return settings, nil
}
strictErr := new(toml.StrictMissingError)
ok := errors.As(err, &strictErr)
if !ok {
return settings, fmt.Errorf("toml decoding file: %w", err)
}
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
strictErr, strictErr.String())
}

View File

@@ -0,0 +1,80 @@
package auth
import (
"io/fs"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Read reads the toml file specified by the filepath given.
func Test_Read(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
fileContent string
settings Settings
errMessage string
}{
"empty_file": {},
"malformed_toml": {
fileContent: "this is not a toml file",
errMessage: `toml decoding file: toml: expected character =`,
},
"unknown_field": {
fileContent: `unknown = "what is this"`,
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
1| unknown = "what is this"
| ~~~~~~~ missing field`,
},
"filled_settings": {
fileContent: `[[roles]]
name = "public"
auth = "none"
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
[[roles]]
name = "client"
auth = "apikey"
apikey = "xyz"
routes = ["GET /v1/vpn/status"]
`,
settings: Settings{
Roles: []Role{{
Name: "public",
Auth: AuthNone,
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
}, {
Name: "client",
Auth: AuthAPIKey,
APIKey: "xyz",
Routes: []string{"GET /v1/vpn/status"},
}},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
filepath := tempDir + "/config.toml"
const permissions fs.FileMode = 0600
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
require.NoError(t, err)
settings, err := Read(filepath)
assert.Equal(t, testCase.settings, settings)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,22 @@
package auth
func andStrings(strings []string) (result string) {
return joinStrings(strings, "and")
}
func joinStrings(strings []string, lastJoin string) (result string) {
if len(strings) == 0 {
return ""
}
result = strings[0]
for i := 1; i < len(strings); i++ {
if i < len(strings)-1 {
result += ", " + strings[i]
} else {
result += " " + lastJoin + " " + strings[i]
}
}
return result
}

View File

@@ -0,0 +1,6 @@
package auth
type DebugLogger interface {
Debugf(format string, args ...any)
Warnf(format string, args ...any)
}

View File

@@ -0,0 +1,8 @@
package auth
import "net/http"
type authorizationChecker interface {
equal(other authorizationChecker) bool
isAuthorized(headers http.Header, request *http.Request) bool
}

View File

@@ -0,0 +1,47 @@
package auth
import (
"fmt"
)
type internalRole struct {
name string
checker authorizationChecker
}
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
routeToRoles = make(map[string][]internalRole)
for _, role := range settings.Roles {
var checker authorizationChecker
switch role.Auth {
case AuthNone:
checker = newNoneMethod()
case AuthAPIKey:
checker = newAPIKeyMethod(role.APIKey)
case AuthBasic:
checker = newBasicAuthMethod(role.Username, role.Password)
default:
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
}
iRole := internalRole{
name: role.Name,
checker: checker,
}
for _, route := range role.Routes {
checkerExists := false
for _, role := range routeToRoles[route] {
if role.checker.equal(iRole.checker) {
checkerExists = true
break
}
}
if checkerExists {
// even if the role name is different, if the checker is the same, skip it.
continue
}
routeToRoles[route] = append(routeToRoles[route], iRole)
}
}
return routeToRoles, nil
}

View File

@@ -0,0 +1,60 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Read reads the toml file specified by the filepath given.
func Test_settingsToLookupMap(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
routeToRoles map[string][]internalRole
errWrapped error
errMessage string
}{
"empty_settings": {
routeToRoles: map[string][]internalRole{},
},
"auth_method_not_supported": {
settings: Settings{
Roles: []Role{{Name: "a", Auth: "bad"}},
},
errWrapped: ErrMethodNotSupported,
errMessage: "authentication method not supported: bad",
},
"success": {
settings: Settings{
Roles: []Role{
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
},
},
routeToRoles: map[string][]internalRole{
"GET /path": {
{name: "a", checker: newNoneMethod()}, // deduplicated method
},
"PUT /path": {
{name: "b", checker: newNoneMethod()},
}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
routeToRoles, err := settingsToLookupMap(testCase.settings)
assert.Equal(t, testCase.routeToRoles, routeToRoles)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -0,0 +1,111 @@
package auth
import (
"fmt"
"net/http"
)
func New(settings Settings, debugLogger DebugLogger) (
middleware func(http.Handler) http.Handler,
err error) {
routeToRoles, err := settingsToLookupMap(settings)
if err != nil {
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
}
//nolint:goconst
return func(handler http.Handler) http.Handler {
return &authHandler{
childHandler: handler,
routeToRoles: routeToRoles,
unprotectedRoutes: map[string]struct{}{
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
// GET /v1/vpn/settings is protected by default
// PUT /v1/vpn/settings is protected by default
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
// GET /v1/openvpn/settings is protected by default
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
},
logger: debugLogger,
}
}, nil
}
type authHandler struct {
childHandler http.Handler
routeToRoles map[string][]internalRole
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
logger DebugLogger
}
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
route := request.Method + " " + request.URL.Path
roles := h.routeToRoles[route]
if len(roles) == 0 {
h.logger.Debugf("no authentication role defined for route %s", route)
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
responseHeader := make(http.Header, 0)
for _, role := range roles {
if !role.checker.isAuthorized(responseHeader, request) {
continue
}
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
h.childHandler.ServeHTTP(writer, request)
return
}
// Flush out response headers if all roles failed to authenticate
for headerKey, headerValues := range responseHeader {
for _, headerValue := range headerValues {
writer.Header().Add(headerKey, headerValue)
}
}
allRoleNames := make([]string, len(roles))
for i, role := range roles {
allRoleNames[i] = role.name
}
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
route, andStrings(allRoleNames))
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
// TODO v3.41.0 remove
if role.name != "public" {
// custom role name, allow none authentication to be specified
return
}
_, isNoneChecker := role.checker.(*noneMethod)
if !isNoneChecker {
// not the none authentication method
return
}
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
if !isUnprotectedByDefault {
// route is not unprotected by default, so this is a user decision
return
}
h.logger.Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
route)
}

View File

@@ -0,0 +1,124 @@
package auth
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_authHandler_ServeHTTP(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
requestMethod string
requestPath string
statusCode int
responseBody string
}{
"route_has_no_role": {
settings: Settings{
Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/b",
statusCode: http.StatusUnauthorized,
responseBody: "Unauthorized\n",
},
"authorized_unprotected_by_default": {
settings: Settings{
Roles: []Role{
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
"GET /v1/vpn/status")
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /v1/vpn/status", "public")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/v1/vpn/status",
statusCode: http.StatusOK,
},
"authorized_none": {
settings: Settings{
Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /a", "role1")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/a",
statusCode: http.StatusOK,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var debugLogger DebugLogger
if testCase.makeLogger != nil {
debugLogger = testCase.makeLogger(ctrl)
}
middleware, err := New(testCase.settings, debugLogger)
require.NoError(t, err)
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := middleware(childHandler)
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client := server.Client()
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
require.NoError(t, err)
request, err := http.NewRequestWithContext(context.Background(),
testCase.requestMethod, requestURL, nil)
require.NoError(t, err)
response, err := client.Do(request)
require.NoError(t, err)
t.Cleanup(func() {
err = response.Body.Close()
assert.NoError(t, err)
})
assert.Equal(t, testCase.statusCode, response.StatusCode)
body, err := io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, testCase.responseBody, string(body))
})
}
}

View File

@@ -0,0 +1,3 @@
package auth
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger

View File

@@ -0,0 +1,68 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
// Package auth is a generated GoMock package.
package auth
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockDebugLogger is a mock of DebugLogger interface.
type MockDebugLogger struct {
ctrl *gomock.Controller
recorder *MockDebugLoggerMockRecorder
}
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
type MockDebugLoggerMockRecorder struct {
mock *MockDebugLogger
}
// NewMockDebugLogger creates a new mock instance.
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
mock := &MockDebugLogger{ctrl: ctrl}
mock.recorder = &MockDebugLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
return m.recorder
}
// Debugf mocks base method.
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debugf", varargs...)
}
// Debugf indicates an expected call of Debugf.
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
}
// Warnf mocks base method.
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Warnf", varargs...)
}
// Warnf indicates an expected call of Warnf.
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
}

View File

@@ -0,0 +1,20 @@
package auth
import "net/http"
type noneMethod struct{}
func newNoneMethod() *noneMethod {
return &noneMethod{}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (n *noneMethod) equal(other authorizationChecker) bool {
_, ok := other.(*noneMethod)
return ok
}
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
return true
}

View File

@@ -0,0 +1,131 @@
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/validate"
)
type Settings struct {
// Roles is a list of roles with their associated authentication
// and routes.
Roles []Role
}
func (s *Settings) SetDefaults() {
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
Name: "public",
Auth: "none",
Routes: []string{
http.MethodGet + " /openvpn/actions/restart",
http.MethodGet + " /unbound/actions/restart",
http.MethodGet + " /updater/restart",
http.MethodGet + " /v1/version",
http.MethodGet + " /v1/vpn/status",
http.MethodPut + " /v1/vpn/status",
http.MethodGet + " /v1/openvpn/status",
http.MethodPut + " /v1/openvpn/status",
http.MethodGet + " /v1/openvpn/portforwarded",
http.MethodGet + " /v1/dns/status",
http.MethodPut + " /v1/dns/status",
http.MethodGet + " /v1/updater/status",
http.MethodPut + " /v1/updater/status",
http.MethodGet + " /v1/publicip/ip",
},
}})
}
func (s Settings) Validate() (err error) {
for i, role := range s.Roles {
err = role.validate()
if err != nil {
return fmt.Errorf("role %s (%d of %d): %w",
role.Name, i+1, len(s.Roles), err)
}
}
return nil
}
const (
AuthNone = "none"
AuthAPIKey = "apikey"
AuthBasic = "basic"
)
// Role contains the role name, authentication method name and
// routes that the role can access.
type Role struct {
// Name is the role name and is only used for documentation
// and in the authentication middleware debug logs.
Name string
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
Auth string
// APIKey is the API key to use when using the 'apikey' authentication.
APIKey string
// Username for HTTP Basic authentication method.
Username string
// Password for HTTP Basic authentication method.
Password string
// Routes is a list of routes that the role can access in the format
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
Routes []string
}
var (
ErrMethodNotSupported = errors.New("authentication method not supported")
ErrAPIKeyEmpty = errors.New("api key is empty")
ErrBasicUsernameEmpty = errors.New("username is empty")
ErrBasicPasswordEmpty = errors.New("password is empty")
ErrRouteNotSupported = errors.New("route not supported by the control server")
)
func (r Role) validate() (err error) {
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
if err != nil {
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
}
switch {
case r.Auth == AuthAPIKey && r.APIKey == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
case r.Auth == AuthBasic && r.Username == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
case r.Auth == AuthBasic && r.Password == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
}
for i, route := range r.Routes {
_, ok := validRoutes[route]
if !ok {
return fmt.Errorf("route %d of %d: %w: %s",
i+1, len(r.Routes), ErrRouteNotSupported, route)
}
}
return nil
}
// WARNING: do not mutate programmatically.
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
http.MethodGet + " /v1/vpn/settings": {},
http.MethodPut + " /v1/vpn/settings": {},
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
http.MethodGet + " /v1/openvpn/settings": {},
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
}

View File

@@ -0,0 +1,5 @@
package log
type Logger interface {
Info(message string)
}

View File

@@ -1,4 +1,4 @@
package server
package log
import (
"net/http"
@@ -7,18 +7,21 @@ import (
"time"
)
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
return &logMiddleware{
childHandler: childHandler,
logger: logger,
timeNow: time.Now,
enabled: enabled,
func New(logger Logger, enabled bool) (
middleware func(http.Handler) http.Handler) {
return func(handler http.Handler) http.Handler {
return &logMiddleware{
childHandler: handler,
logger: logger,
timeNow: time.Now,
enabled: enabled,
}
}
}
type logMiddleware struct {
childHandler http.Handler
logger infoer
logger Logger
timeNow func() time.Time
enabled bool
enabledMu sync.RWMutex
@@ -39,7 +42,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.RemoteAddr + " in " + duration.String())
}
func (m *logMiddleware) setEnabled(enabled bool) {
func (m *logMiddleware) SetEnabled(enabled bool) {
m.enabledMu.Lock()
defer m.enabledMu.Unlock()
m.enabled = enabled

View File

@@ -6,17 +6,31 @@ import (
"github.com/qdm12/gluetun/internal/httpserver"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
)
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) (
server *httpserver.Server, err error) {
handler := newHandler(ctx, logger, logEnabled, buildInfo,
authSettings, err := auth.Read(authConfigPath)
if err != nil {
return nil, fmt.Errorf("reading auth settings: %w", err)
}
authSettings.SetDefaults()
err = authSettings.Validate()
if err != nil {
return nil, fmt.Errorf("validating auth settings: %w", err)
}
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
if err != nil {
return nil, fmt.Errorf("creating handler: %w", err)
}
httpServerSettings := httpserver.Settings{
Address: address,

View File

@@ -128,6 +128,31 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
messageParts = append(messageParts, "premium tier only")
}
if *selection.StreamOnly {
messageParts = append(messageParts, "stream only")
}
if *selection.MultiHopOnly {
messageParts = append(messageParts, "multihop only")
}
if *selection.PortForwardOnly {
messageParts = append(messageParts, "port forwarding only")
}
if *selection.SecureCoreOnly {
messageParts = append(messageParts, "secure core only")
}
if *selection.TorOnly {
messageParts = append(messageParts, "tor only")
}
if selection.TargetIP.IsValid() {
messageParts = append(messageParts,
"target ip address "+selection.TargetIP.String())
}
message := "for " + strings.Join(messageParts, "; ")
return fmt.Errorf("%w: %s", ErrNoServerFound, message)

File diff suppressed because it is too large Load Diff