Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67ae5f5065 | ||
|
|
cbfdb25190 | ||
|
|
638f233b3c | ||
|
|
c450c54d67 | ||
|
|
d166314f8b | ||
|
|
7064a44403 | ||
|
|
c33158c13c |
@@ -197,6 +197,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
# Control server
|
# Control server
|
||||||
HTTP_CONTROL_SERVER_LOG=on \
|
HTTP_CONTROL_SERVER_LOG=on \
|
||||||
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
||||||
|
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
|
||||||
# Server data updater
|
# Server data updater
|
||||||
UPDATER_PERIOD=0 \
|
UPDATER_PERIOD=0 \
|
||||||
UPDATER_MIN_RATIO=0.8 \
|
UPDATER_MIN_RATIO=0.8 \
|
||||||
|
|||||||
@@ -161,12 +161,14 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
return cli.Update(ctx, args[2:], logger)
|
return cli.Update(ctx, args[2:], logger)
|
||||||
case "format-servers":
|
case "format-servers":
|
||||||
return cli.FormatServers(args[2:])
|
return cli.FormatServers(args[2:])
|
||||||
|
case "genkey":
|
||||||
|
return cli.GenKey(args[2:])
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
|
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
Commit: buildInfo.Commit,
|
Commit: buildInfo.Commit,
|
||||||
Created: buildInfo.Created,
|
Created: buildInfo.Created,
|
||||||
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
|
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||||
AnnounceExp: announcementExp,
|
AnnounceExp: announcementExp,
|
||||||
// Sponsor information
|
// Sponsor information
|
||||||
PaypalUser: "qmcgaw",
|
PaypalUser: "qmcgaw",
|
||||||
@@ -474,6 +476,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||||
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
||||||
logger.New(log.SetComponent("http server")),
|
logger.New(log.SetComponent("http server")),
|
||||||
|
allSettings.ControlServer.AuthFilePath,
|
||||||
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
||||||
storage, ipv6Supported)
|
storage, ipv6Supported)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -595,6 +598,7 @@ type clier interface {
|
|||||||
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
||||||
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
||||||
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
||||||
|
GenKey(args []string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tun interface {
|
type Tun interface {
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -8,6 +8,7 @@ require (
|
|||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
github.com/klauspost/compress v1.17.8
|
github.com/klauspost/compress v1.17.8
|
||||||
github.com/klauspost/pgzip v1.2.6
|
github.com/klauspost/pgzip v1.2.6
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/qdm12/dns v1.11.0
|
github.com/qdm12/dns v1.11.0
|
||||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
||||||
github.com/qdm12/gosettings v0.4.2
|
github.com/qdm12/gosettings v0.4.2
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -83,6 +83,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
|||||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||||
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
@@ -113,10 +115,16 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
|
|||||||
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
||||||
|
|||||||
66
internal/cli/genkey.go
Normal file
66
internal/cli/genkey.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *CLI) GenKey(args []string) (err error) {
|
||||||
|
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
|
||||||
|
err = flagSet.Parse(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const keyLength = 128 / 8
|
||||||
|
keyBytes := make([]byte, keyLength)
|
||||||
|
|
||||||
|
_, _ = rand.Read(keyBytes)
|
||||||
|
|
||||||
|
key := base58Encode(keyBytes)
|
||||||
|
fmt.Println(key)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func base58Encode(data []byte) string {
|
||||||
|
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||||
|
const radix = 58
|
||||||
|
|
||||||
|
zcount := 0
|
||||||
|
for zcount < len(data) && data[zcount] == 0 {
|
||||||
|
zcount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// integer simplification of ceil(log(256)/log(58))
|
||||||
|
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
|
||||||
|
size := zcount + ceilLog256Div58
|
||||||
|
|
||||||
|
output := make([]byte, size)
|
||||||
|
|
||||||
|
high := size - 1
|
||||||
|
for _, b := range data {
|
||||||
|
i := size - 1
|
||||||
|
for carry := uint32(b); i > high || carry != 0; i-- {
|
||||||
|
carry += 256 * uint32(output[i]) //nolint:gomnd
|
||||||
|
output[i] = byte(carry % radix)
|
||||||
|
carry /= radix
|
||||||
|
}
|
||||||
|
high = i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the additional "zero-gap" in the output buffer
|
||||||
|
additionalZeroGapEnd := zcount
|
||||||
|
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
|
||||||
|
additionalZeroGapEnd++
|
||||||
|
}
|
||||||
|
|
||||||
|
val := output[additionalZeroGapEnd-zcount:]
|
||||||
|
size = len(val)
|
||||||
|
for i := range val {
|
||||||
|
output[i] = alphabet[val[i]]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(output[:size])
|
||||||
|
}
|
||||||
@@ -19,6 +19,11 @@ type ControlServer struct {
|
|||||||
// Log can be true or false to enable logging on requests.
|
// Log can be true or false to enable logging on requests.
|
||||||
// It cannot be nil in the internal state.
|
// It cannot be nil in the internal state.
|
||||||
Log *bool
|
Log *bool
|
||||||
|
// AuthFilePath is the path to the file containing the authentication
|
||||||
|
// configuration for the middleware.
|
||||||
|
// It cannot be empty in the internal state and defaults to
|
||||||
|
// /gluetun/auth/config.toml.
|
||||||
|
AuthFilePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ControlServer) validate() (err error) {
|
func (c ControlServer) validate() (err error) {
|
||||||
@@ -44,8 +49,9 @@ func (c ControlServer) validate() (err error) {
|
|||||||
|
|
||||||
func (c *ControlServer) copy() (copied ControlServer) {
|
func (c *ControlServer) copy() (copied ControlServer) {
|
||||||
return ControlServer{
|
return ControlServer{
|
||||||
Address: gosettings.CopyPointer(c.Address),
|
Address: gosettings.CopyPointer(c.Address),
|
||||||
Log: gosettings.CopyPointer(c.Log),
|
Log: gosettings.CopyPointer(c.Log),
|
||||||
|
AuthFilePath: c.AuthFilePath,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,11 +61,13 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
|||||||
func (c *ControlServer) overrideWith(other ControlServer) {
|
func (c *ControlServer) overrideWith(other ControlServer) {
|
||||||
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
||||||
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
||||||
|
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ControlServer) setDefaults() {
|
func (c *ControlServer) setDefaults() {
|
||||||
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
||||||
c.Log = gosettings.DefaultPointer(c.Log, true)
|
c.Log = gosettings.DefaultPointer(c.Log, true)
|
||||||
|
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ControlServer) String() string {
|
func (c ControlServer) String() string {
|
||||||
@@ -70,6 +78,7 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
|
|||||||
node = gotree.New("Control server settings:")
|
node = gotree.New("Control server settings:")
|
||||||
node.Appendf("Listening address: %s", *c.Address)
|
node.Appendf("Listening address: %s", *c.Address)
|
||||||
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
||||||
|
node.Appendf("Authentication file path: %s", c.AuthFilePath)
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,6 +87,10 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
||||||
|
|
||||||
|
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
|||||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
|
if vpnServiceProvider == providers.Custom {
|
||||||
// Allow a single name to be specified for the custom provider in case
|
switch len(settings.Names) {
|
||||||
// the user wants to use VPN server side port forwarding with PIA
|
case 0:
|
||||||
// which requires a server name for TLS verification.
|
case 1:
|
||||||
filterChoices.Names = settings.Names
|
// Allow a single name to be specified for the custom provider in case
|
||||||
|
// the user wants to use VPN server side port forwarding with PIA
|
||||||
|
// which requires a server name for TLS verification.
|
||||||
|
filterChoices.Names = settings.Names
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %d names specified instead of "+
|
||||||
|
"0 or 1 for the custom provider",
|
||||||
|
ErrNameNotValid, len(settings.Names))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -229,6 +237,8 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
|
|||||||
switch {
|
switch {
|
||||||
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
||||||
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
||||||
|
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
||||||
|
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
|
||||||
case *settings.StreamOnly &&
|
case *settings.StreamOnly &&
|
||||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||||
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
||||||
|
|||||||
@@ -78,7 +78,8 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| └── Enabled: no
|
| └── Enabled: no
|
||||||
├── Control server settings:
|
├── Control server settings:
|
||||||
| ├── Listening address: :8000
|
| ├── Listening address: :8000
|
||||||
| └── Logging: yes
|
| ├── Logging: yes
|
||||||
|
| └── Authentication file path: /gluetun/auth/config.toml
|
||||||
├── OS Alpine settings:
|
├── OS Alpine settings:
|
||||||
| ├── Process UID: 1000
|
| ├── Process UID: 1000
|
||||||
| └── Process GID: 1000
|
| └── Process GID: 1000
|
||||||
|
|||||||
98
internal/firewall/delete.go
Normal file
98
internal/firewall/delete.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// isDeleteMatchInstruction returns true if the iptables instruction
|
||||||
|
// is a delete instruction by rule matching. It returns false if the
|
||||||
|
// instruction is a delete instruction by line number, or not a delete
|
||||||
|
// instruction.
|
||||||
|
func isDeleteMatchInstruction(instruction string) bool {
|
||||||
|
fields := strings.Fields(instruction)
|
||||||
|
for i, field := range fields {
|
||||||
|
switch {
|
||||||
|
case field != "-D" && field != "--delete": //nolint:goconst
|
||||||
|
continue
|
||||||
|
case i == len(fields)-1: // malformed: missing chain name
|
||||||
|
return false
|
||||||
|
case i == len(fields)-2: // chain name is last field
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
// chain name is fields[i+1]
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
|
||||||
|
return err != nil // not a line number
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
|
||||||
|
runner Runner, logger Logger) (err error) {
|
||||||
|
targetRule, err := parseIptablesInstruction(instruction)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing iptables command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lineNumber, err := findLineNumber(ctx, iptablesBinary,
|
||||||
|
targetRule, runner, logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("finding iptables chain rule line number: %w", err)
|
||||||
|
} else if lineNumber == 0 {
|
||||||
|
logger.Debug("rule matching \"" + instruction + "\" not found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
|
||||||
|
instruction, lineNumber))
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
|
||||||
|
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
|
||||||
|
logger.Debug(cmd.String())
|
||||||
|
output, err := runner.Run(cmd)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||||
|
if output != "" {
|
||||||
|
err = fmt.Errorf("%w: %s", err, output)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findLineNumber finds the line number of an iptables rule.
|
||||||
|
// It returns 0 if the rule is not found.
|
||||||
|
func findLineNumber(ctx context.Context, iptablesBinary string,
|
||||||
|
instruction iptablesInstruction, runner Runner, logger Logger) (
|
||||||
|
lineNumber uint16, err error) {
|
||||||
|
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
|
||||||
|
"--line-numbers", "-n", "-v"}
|
||||||
|
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
|
||||||
|
logger.Debug(cmd.String())
|
||||||
|
output, err := runner.Run(cmd)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||||
|
if output != "" {
|
||||||
|
err = fmt.Errorf("%w: %s", err, output)
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
chain, err := parseChain(output)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parsing chain list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range chain.rules {
|
||||||
|
if instruction.equalToRule(instruction.table, chain.name, rule) {
|
||||||
|
return rule.lineNumber, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
188
internal/firewall/delete_test.go
Normal file
188
internal/firewall/delete_test.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_isDeleteMatchInstruction(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
instruction string
|
||||||
|
isDeleteMatch bool
|
||||||
|
}{
|
||||||
|
"not_delete": {
|
||||||
|
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
|
||||||
|
},
|
||||||
|
"malformed_missing_chain_name": {
|
||||||
|
instruction: "-t nat -D",
|
||||||
|
},
|
||||||
|
"delete_chain_name_last_field": {
|
||||||
|
instruction: "-t nat --delete PREROUTING",
|
||||||
|
isDeleteMatch: true,
|
||||||
|
},
|
||||||
|
"delete_match": {
|
||||||
|
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
|
||||||
|
isDeleteMatch: true,
|
||||||
|
},
|
||||||
|
"delete_line_number_last_field": {
|
||||||
|
instruction: "-t nat -D PREROUTING 2",
|
||||||
|
},
|
||||||
|
"delete_line_number": {
|
||||||
|
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
|
||||||
|
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
|
||||||
|
"^--line-numbers$", "^-n$", "^-v$")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_deleteIPTablesRule(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const iptablesBinary = "/sbin/iptables"
|
||||||
|
errTest := errors.New("test error")
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
instruction string
|
||||||
|
makeRunner func(ctrl *gomock.Controller) *MockRunner
|
||||||
|
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"invalid_instruction": {
|
||||||
|
instruction: "invalid",
|
||||||
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
|
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||||
|
"fields count 1 is not even: \"invalid\"",
|
||||||
|
},
|
||||||
|
"list_error": {
|
||||||
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||||
|
runner := NewMockRunner(ctrl)
|
||||||
|
runner.EXPECT().
|
||||||
|
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||||
|
Return("", errTest)
|
||||||
|
return runner
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
errWrapped: errTest,
|
||||||
|
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||||
|
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||||
|
},
|
||||||
|
"rule_not_found": {
|
||||||
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||||
|
runner := NewMockRunner(ctrl)
|
||||||
|
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||||
|
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
|
||||||
|
num pkts bytes target prot opt in out source destination
|
||||||
|
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
|
||||||
|
nil)
|
||||||
|
return runner
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||||
|
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
|
||||||
|
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"rule_found_delete_error": {
|
||||||
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||||
|
runner := NewMockRunner(ctrl)
|
||||||
|
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||||
|
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||||
|
"num pkts bytes target prot opt in out source destination \n"+
|
||||||
|
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||||
|
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||||
|
nil)
|
||||||
|
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||||
|
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||||
|
return runner
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||||
|
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||||
|
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||||
|
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
errWrapped: errTest,
|
||||||
|
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||||
|
},
|
||||||
|
"rule_found_delete_success": {
|
||||||
|
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||||
|
runner := NewMockRunner(ctrl)
|
||||||
|
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||||
|
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||||
|
"num pkts bytes target prot opt in out source destination \n"+
|
||||||
|
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||||
|
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||||
|
nil)
|
||||||
|
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||||
|
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
|
||||||
|
return runner
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
logger := NewMockLogger(ctrl)
|
||||||
|
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||||
|
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||||
|
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||||
|
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
instruction := testCase.instruction
|
||||||
|
var runner *MockRunner
|
||||||
|
if testCase.makeRunner != nil {
|
||||||
|
runner = testCase.makeRunner(ctrl)
|
||||||
|
}
|
||||||
|
var logger *MockLogger
|
||||||
|
if testCase.makeLogger != nil {
|
||||||
|
logger = testCase.makeLogger(ctrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
13
internal/firewall/interfaces.go
Normal file
13
internal/firewall/interfaces.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import "github.com/qdm12/golibs/command"
|
||||||
|
|
||||||
|
type Runner interface {
|
||||||
|
Run(cmd command.ExecCmd) (output string, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(s string)
|
||||||
|
Info(s string)
|
||||||
|
Error(s string)
|
||||||
|
}
|
||||||
@@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
|
|||||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||||
defer c.ip6tablesMutex.Unlock()
|
defer c.ip6tablesMutex.Unlock()
|
||||||
|
|
||||||
c.logger.Debug(c.ip6Tables + " " + instruction)
|
if isDeleteMatchInstruction(instruction) {
|
||||||
|
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||||
|
c.runner, c.logger)
|
||||||
|
}
|
||||||
|
|
||||||
flags := strings.Fields(instruction)
|
flags := strings.Fields(instruction)
|
||||||
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
||||||
|
c.logger.Debug(cmd.String())
|
||||||
if output, err := c.runner.Run(cmd); err != nil {
|
if output, err := c.runner.Run(cmd); err != nil {
|
||||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||||
c.ip6Tables, instruction, output, err)
|
c.ip6Tables, instruction, output, err)
|
||||||
@@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
|
|||||||
|
|
||||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||||
switch policy {
|
switch policy {
|
||||||
case "ACCEPT", "DROP":
|
case "ACCEPT", "DROP": //nolint:goconst
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
|||||||
c.iptablesMutex.Lock() // only one iptables command at once
|
c.iptablesMutex.Lock() // only one iptables command at once
|
||||||
defer c.iptablesMutex.Unlock()
|
defer c.iptablesMutex.Unlock()
|
||||||
|
|
||||||
c.logger.Debug(c.ipTables + " " + instruction)
|
if isDeleteMatchInstruction(instruction) {
|
||||||
|
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||||
|
c.runner, c.logger)
|
||||||
|
}
|
||||||
|
|
||||||
flags := strings.Fields(instruction)
|
flags := strings.Fields(instruction)
|
||||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||||
|
c.logger.Debug(cmd.String())
|
||||||
if output, err := c.runner.Run(cmd); err != nil {
|
if output, err := c.runner.Run(cmd); err != nil {
|
||||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||||
c.ipTables, instruction, output, err)
|
c.ipTables, instruction, output, err)
|
||||||
@@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
|||||||
defaultInterface string, connection models.Connection, remove bool) error {
|
defaultInterface string, connection models.Connection, remove bool) error {
|
||||||
protocol := connection.Protocol
|
protocol := connection.Protocol
|
||||||
if protocol == "tcp-client" {
|
if protocol == "tcp-client" {
|
||||||
protocol = "tcp"
|
protocol = "tcp" //nolint:goconst
|
||||||
}
|
}
|
||||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||||
|
|||||||
381
internal/firewall/list.go
Normal file
381
internal/firewall/list.go
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type chain struct {
|
||||||
|
name string
|
||||||
|
policy string
|
||||||
|
packets uint64
|
||||||
|
bytes uint64
|
||||||
|
rules []chainRule
|
||||||
|
}
|
||||||
|
|
||||||
|
type chainRule struct {
|
||||||
|
lineNumber uint16 // starts from 1 and cannot be zero.
|
||||||
|
packets uint64
|
||||||
|
bytes uint64
|
||||||
|
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
||||||
|
protocol string // "tcp", "udp" or "" for all protocols.
|
||||||
|
inputInterface string // input interface, for example "tun0" or "*""
|
||||||
|
outputInterface string // output interface, for example "eth0" or "*""
|
||||||
|
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||||
|
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||||
|
destinationPort uint16 // Not specified if set to zero.
|
||||||
|
redirPorts []uint16 // Not specified if empty.
|
||||||
|
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||||
|
// Text example:
|
||||||
|
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
|
// pkts bytes target prot opt in out source destination
|
||||||
|
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||||
|
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||||
|
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||||
|
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
||||||
|
linesWithComments := strings.Split(iptablesOutput, "\n")
|
||||||
|
|
||||||
|
// Filter out lines starting with a '#' character
|
||||||
|
lines := make([]string, 0, len(linesWithComments))
|
||||||
|
for _, line := range linesWithComments {
|
||||||
|
if strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lines = append(lines, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
const minLines = 2 // chain general information line + legend line
|
||||||
|
if len(lines) < minLines {
|
||||||
|
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||||
|
ErrChainListMalformed, iptablesOutput)
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err = parseChainGeneralDataLine(lines[0])
|
||||||
|
if err != nil {
|
||||||
|
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity check for the legend line
|
||||||
|
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
||||||
|
legendLine := strings.TrimSpace(lines[1])
|
||||||
|
legendFields := strings.Fields(legendLine)
|
||||||
|
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||||
|
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||||
|
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
lines = lines[2:] // remove chain general information line and legend line
|
||||||
|
if len(lines) == 0 {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.rules = make([]chainRule, len(lines))
|
||||||
|
for i, line := range lines {
|
||||||
|
c.rules[i], err = parseChainRuleLine(line)
|
||||||
|
if err != nil {
|
||||||
|
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseChainGeneralDataLine parses the first line of iptables chain list output.
|
||||||
|
// For example, it can parse the following line:
|
||||||
|
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
|
// It returns a chain struct with the parsed data.
|
||||||
|
func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
runesToRemove := []rune{'(', ')', ','}
|
||||||
|
for _, r := range runesToRemove {
|
||||||
|
line = strings.ReplaceAll(line, string(r), "")
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
const expectedNumberOfFields = 8
|
||||||
|
if len(fields) != expectedNumberOfFields {
|
||||||
|
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||||
|
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity checks
|
||||||
|
indexToExpectedValue := map[int]string{
|
||||||
|
0: "Chain",
|
||||||
|
2: "policy",
|
||||||
|
5: "packets",
|
||||||
|
7: "bytes",
|
||||||
|
}
|
||||||
|
for index, expectedValue := range indexToExpectedValue {
|
||||||
|
if fields[index] == expectedValue {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||||
|
ErrChainListMalformed, expectedValue, index, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
base.name = fields[1] // chain name could be custom
|
||||||
|
base.policy = fields[3]
|
||||||
|
err = checkTarget(base.policy)
|
||||||
|
if err != nil {
|
||||||
|
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packets, err := parseMetricSize(fields[4])
|
||||||
|
if err != nil {
|
||||||
|
return chain{}, fmt.Errorf("parsing packets: %w", err)
|
||||||
|
}
|
||||||
|
base.packets = packets
|
||||||
|
|
||||||
|
bytes, err := parseMetricSize(fields[6])
|
||||||
|
if err != nil {
|
||||||
|
return chain{}, fmt.Errorf("parsing bytes: %w", err)
|
||||||
|
}
|
||||||
|
base.bytes = bytes
|
||||||
|
|
||||||
|
return base, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
|
||||||
|
const minFields = 10
|
||||||
|
if len(fields) < minFields {
|
||||||
|
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldIndex, field := range fields[:minFields] {
|
||||||
|
err = parseChainRuleField(fieldIndex, field, &rule)
|
||||||
|
if err != nil {
|
||||||
|
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fields) > minFields {
|
||||||
|
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
||||||
|
if err != nil {
|
||||||
|
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||||
|
if field == "" {
|
||||||
|
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
numIndex = iota
|
||||||
|
packetsIndex
|
||||||
|
bytesIndex
|
||||||
|
targetIndex
|
||||||
|
protocolIndex
|
||||||
|
optIndex
|
||||||
|
inputInterfaceIndex
|
||||||
|
outputInterfaceIndex
|
||||||
|
sourceIndex
|
||||||
|
destinationIndex
|
||||||
|
)
|
||||||
|
|
||||||
|
switch fieldIndex {
|
||||||
|
case numIndex:
|
||||||
|
rule.lineNumber, err = parseLineNumber(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing line number: %w", err)
|
||||||
|
}
|
||||||
|
case packetsIndex:
|
||||||
|
rule.packets, err = parseMetricSize(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing packets: %w", err)
|
||||||
|
}
|
||||||
|
case bytesIndex:
|
||||||
|
rule.bytes, err = parseMetricSize(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing bytes: %w", err)
|
||||||
|
}
|
||||||
|
case targetIndex:
|
||||||
|
err = checkTarget(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("checking target: %w", err)
|
||||||
|
}
|
||||||
|
rule.target = field
|
||||||
|
case protocolIndex:
|
||||||
|
rule.protocol, err = parseProtocol(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing protocol: %w", err)
|
||||||
|
}
|
||||||
|
case optIndex: // ignored
|
||||||
|
case inputInterfaceIndex:
|
||||||
|
rule.inputInterface = field
|
||||||
|
case outputInterfaceIndex:
|
||||||
|
rule.outputInterface = field
|
||||||
|
case sourceIndex:
|
||||||
|
rule.source, err = parseIPPrefix(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||||
|
}
|
||||||
|
case destinationIndex:
|
||||||
|
rule.destination, err = parseIPPrefix(field)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||||
|
for i := 0; i < len(optionalFields); i++ {
|
||||||
|
key := optionalFields[i]
|
||||||
|
switch key {
|
||||||
|
case "tcp", "udp":
|
||||||
|
i++
|
||||||
|
value := optionalFields[i]
|
||||||
|
value = strings.TrimPrefix(value, "dpt:")
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||||
|
}
|
||||||
|
rule.destinationPort = uint16(destinationPort)
|
||||||
|
case "redir":
|
||||||
|
i++
|
||||||
|
switch optionalFields[i] {
|
||||||
|
case "ports":
|
||||||
|
i++
|
||||||
|
ports, err := parsePortsCSV(optionalFields[i])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||||
|
}
|
||||||
|
rule.redirPorts = ports
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||||
|
ErrChainRuleMalformed, optionalFields[i])
|
||||||
|
}
|
||||||
|
case "ctstate":
|
||||||
|
i++
|
||||||
|
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||||
|
if s == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Split(s, ",")
|
||||||
|
ports = make([]uint16, len(fields))
|
||||||
|
for i, field := range fields {
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
port, err := strconv.ParseUint(field, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||||
|
}
|
||||||
|
ports[i] = uint16(port)
|
||||||
|
}
|
||||||
|
return ports, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrLineNumberIsZero = errors.New("line number is zero")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseLineNumber(s string) (n uint16, err error) {
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else if lineNumber == 0 {
|
||||||
|
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||||
|
}
|
||||||
|
return uint16(lineNumber), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTargetUnknown = errors.New("unknown target")
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkTarget(target string) (err error) {
|
||||||
|
switch target {
|
||||||
|
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrProtocolUnknown = errors.New("unknown protocol")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseProtocol(s string) (protocol string, err error) {
|
||||||
|
switch s {
|
||||||
|
case "0":
|
||||||
|
case "6":
|
||||||
|
protocol = "tcp"
|
||||||
|
case "17":
|
||||||
|
protocol = "udp"
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||||
|
}
|
||||||
|
return protocol, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||||
|
// returns the raw integer matching it.
|
||||||
|
func parseMetricSize(size string) (n uint64, err error) {
|
||||||
|
if size == "" {
|
||||||
|
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gomnd
|
||||||
|
multiplerLetterToValue := map[byte]uint64{
|
||||||
|
'K': 1000,
|
||||||
|
'M': 1000000,
|
||||||
|
'G': 1000000000,
|
||||||
|
'T': 1000000000000,
|
||||||
|
}
|
||||||
|
|
||||||
|
lastCharacter := size[len(size)-1]
|
||||||
|
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
||||||
|
if ok { // multiplier present
|
||||||
|
size = size[:len(size)-1]
|
||||||
|
} else {
|
||||||
|
multiplier = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
const base, bitLength = 10, 64
|
||||||
|
n, err = strconv.ParseUint(size, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||||
|
}
|
||||||
|
n *= multiplier
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
121
internal/firewall/list_test.go
Normal file
121
internal/firewall/list_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseChain(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
iptablesOutput string
|
||||||
|
table chain
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"no_output": {
|
||||||
|
errWrapped: ErrChainListMalformed,
|
||||||
|
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||||
|
},
|
||||||
|
"single_line_only": {
|
||||||
|
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||||
|
errWrapped: ErrChainListMalformed,
|
||||||
|
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||||
|
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||||
|
},
|
||||||
|
"malformed_general_data_line": {
|
||||||
|
iptablesOutput: `Chain INPUT
|
||||||
|
num pkts bytes target prot opt in out source destination`,
|
||||||
|
errWrapped: ErrChainListMalformed,
|
||||||
|
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||||
|
"expected 8 fields in \"Chain INPUT\"",
|
||||||
|
},
|
||||||
|
"malformed_legend": {
|
||||||
|
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
|
num pkts bytes target prot opt in out source`,
|
||||||
|
errWrapped: ErrChainListMalformed,
|
||||||
|
errMessage: "iptables chain list output is malformed: legend " +
|
||||||
|
"\"num pkts bytes target prot opt in out source\" " +
|
||||||
|
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||||
|
},
|
||||||
|
"no_rule": {
|
||||||
|
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
|
num pkts bytes target prot opt in out source destination`,
|
||||||
|
table: chain{
|
||||||
|
name: "INPUT",
|
||||||
|
policy: "ACCEPT",
|
||||||
|
packets: 140000,
|
||||||
|
bytes: 226000000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"some_rules": {
|
||||||
|
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||||
|
num pkts bytes target prot opt in out source destination
|
||||||
|
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||||
|
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||||
|
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
|
||||||
|
`,
|
||||||
|
table: chain{
|
||||||
|
name: "INPUT",
|
||||||
|
policy: "ACCEPT",
|
||||||
|
packets: 140000,
|
||||||
|
bytes: 226000000,
|
||||||
|
rules: []chainRule{
|
||||||
|
{
|
||||||
|
lineNumber: 1,
|
||||||
|
packets: 0,
|
||||||
|
bytes: 0,
|
||||||
|
target: "ACCEPT",
|
||||||
|
protocol: "udp",
|
||||||
|
inputInterface: "tun0",
|
||||||
|
outputInterface: "*",
|
||||||
|
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
destinationPort: 55405,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
lineNumber: 2,
|
||||||
|
packets: 0,
|
||||||
|
bytes: 0,
|
||||||
|
target: "ACCEPT",
|
||||||
|
protocol: "tcp",
|
||||||
|
inputInterface: "tun0",
|
||||||
|
outputInterface: "*",
|
||||||
|
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
destinationPort: 55405,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
lineNumber: 3,
|
||||||
|
packets: 0,
|
||||||
|
bytes: 0,
|
||||||
|
target: "DROP",
|
||||||
|
protocol: "",
|
||||||
|
inputInterface: "tun0",
|
||||||
|
outputInterface: "*",
|
||||||
|
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||||
|
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
table, err := parseChain(testCase.iptablesOutput)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.table, table)
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,12 +5,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Logger interface {
|
|
||||||
Debug(s string)
|
|
||||||
Info(s string)
|
|
||||||
Error(s string)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
|
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
|
||||||
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
|
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
|
||||||
"no default route matching its family", subnet))
|
"no default route matching its family", subnet))
|
||||||
|
|||||||
3
internal/firewall/mocks_generate_test.go
Normal file
3
internal/firewall/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger
|
||||||
109
internal/firewall/mocks_test.go
Normal file
109
internal/firewall/mocks_test.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
|
||||||
|
|
||||||
|
// Package firewall is a generated GoMock package.
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
command "github.com/qdm12/golibs/command"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRunner is a mock of Runner interface.
|
||||||
|
type MockRunner struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockRunnerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||||
|
type MockRunnerMockRecorder struct {
|
||||||
|
mock *MockRunner
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockRunner creates a new mock instance.
|
||||||
|
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||||
|
mock := &MockRunner{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run mocks base method.
|
||||||
|
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Run", arg0)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run indicates an expected call of Run.
|
||||||
|
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLogger is a mock of Logger interface.
|
||||||
|
type MockLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||||
|
type MockLoggerMockRecorder struct {
|
||||||
|
mock *MockLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLogger creates a new mock instance.
|
||||||
|
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||||
|
mock := &MockLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug mocks base method.
|
||||||
|
func (m *MockLogger) Debug(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Debug", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug indicates an expected call of Debug.
|
||||||
|
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error mocks base method.
|
||||||
|
func (m *MockLogger) Error(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Error", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error indicates an expected call of Error.
|
||||||
|
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info mocks base method.
|
||||||
|
func (m *MockLogger) Info(arg0 string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "Info", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info indicates an expected call of Info.
|
||||||
|
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||||
|
}
|
||||||
166
internal/firewall/parse.go
Normal file
166
internal/firewall/parse.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type iptablesInstruction struct {
|
||||||
|
table string // defaults to "filter", and can be "nat" for example.
|
||||||
|
append bool
|
||||||
|
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||||
|
target string // for example ACCEPT. Can be empty.
|
||||||
|
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||||
|
inputInterface string // for example "tun0" or "" for any interface.
|
||||||
|
outputInterface string // for example "tun0" or "" for any interface.
|
||||||
|
source netip.Prefix // if not valid, then it is unspecified.
|
||||||
|
destination netip.Prefix // if not valid, then it is unspecified.
|
||||||
|
destinationPort uint16 // if zero, there is no destination port
|
||||||
|
toPorts []uint16 // if empty, there is no redirection
|
||||||
|
ctstate []string // if empty, there is no ctstate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iptablesInstruction) setDefaults() {
|
||||||
|
if i.table == "" {
|
||||||
|
i.table = "filter"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||||
|
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||||
|
switch {
|
||||||
|
case i.table != table:
|
||||||
|
return false
|
||||||
|
case i.chain != chain:
|
||||||
|
return false
|
||||||
|
case i.target != rule.target:
|
||||||
|
return false
|
||||||
|
case i.protocol != rule.protocol:
|
||||||
|
return false
|
||||||
|
case i.destinationPort != rule.destinationPort:
|
||||||
|
return false
|
||||||
|
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||||
|
return false
|
||||||
|
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||||
|
return false
|
||||||
|
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||||
|
return false
|
||||||
|
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||||
|
return false
|
||||||
|
case !ipPrefixesEqual(i.source, rule.source):
|
||||||
|
return false
|
||||||
|
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||||
|
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||||
|
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||||
|
return instruction == chainRule ||
|
||||||
|
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||||
|
if s == "" {
|
||||||
|
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||||
|
}
|
||||||
|
fields := strings.Fields(s)
|
||||||
|
if len(fields)%2 != 0 {
|
||||||
|
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||||
|
ErrIptablesCommandMalformed, len(fields), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(fields); i += 2 {
|
||||||
|
key := fields[i]
|
||||||
|
value := fields[i+1]
|
||||||
|
err = parseInstructionFlag(key, value, &instruction)
|
||||||
|
if err != nil {
|
||||||
|
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
instruction.setDefaults()
|
||||||
|
return instruction, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||||
|
switch key {
|
||||||
|
case "-t", "--table":
|
||||||
|
instruction.table = value
|
||||||
|
case "-D", "--delete":
|
||||||
|
instruction.append = false
|
||||||
|
instruction.chain = value
|
||||||
|
case "-A", "--append":
|
||||||
|
instruction.append = true
|
||||||
|
instruction.chain = value
|
||||||
|
case "-j", "--jump":
|
||||||
|
instruction.target = value
|
||||||
|
case "-p", "--protocol":
|
||||||
|
instruction.protocol = value
|
||||||
|
case "-m", "--match": // ignore match
|
||||||
|
case "-i", "--in-interface":
|
||||||
|
instruction.inputInterface = value
|
||||||
|
case "-o", "--out-interface":
|
||||||
|
instruction.outputInterface = value
|
||||||
|
case "-s", "--source":
|
||||||
|
instruction.source, err = parseIPPrefix(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||||
|
}
|
||||||
|
case "-d", "--destination":
|
||||||
|
instruction.destination, err = parseIPPrefix(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||||
|
}
|
||||||
|
case "--dport":
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing destination port: %w", err)
|
||||||
|
}
|
||||||
|
instruction.destinationPort = uint16(destinationPort)
|
||||||
|
case "--ctstate":
|
||||||
|
instruction.ctstate = strings.Split(value, ",")
|
||||||
|
case "--to-ports":
|
||||||
|
portStrings := strings.Split(value, ",")
|
||||||
|
instruction.toPorts = make([]uint16, len(portStrings))
|
||||||
|
for i, portString := range portStrings {
|
||||||
|
const base, bitLength = 10, 16
|
||||||
|
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing port redirection: %w", err)
|
||||||
|
}
|
||||||
|
instruction.toPorts[i] = uint16(port)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||||
|
slashIndex := strings.Index(value, "/")
|
||||||
|
if slashIndex >= 0 {
|
||||||
|
return netip.ParsePrefix(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := netip.ParseAddr(value)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||||
|
}
|
||||||
|
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||||
|
}
|
||||||
138
internal/firewall/parse_test.go
Normal file
138
internal/firewall/parse_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseIptablesInstruction(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
s string
|
||||||
|
instruction iptablesInstruction
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"no_instruction": {
|
||||||
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
|
errMessage: "iptables command is malformed: empty instruction",
|
||||||
|
},
|
||||||
|
"uneven_fields": {
|
||||||
|
s: "-A",
|
||||||
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
|
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||||
|
},
|
||||||
|
"unknown_key": {
|
||||||
|
s: "-x something",
|
||||||
|
errWrapped: ErrIptablesCommandMalformed,
|
||||||
|
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||||
|
},
|
||||||
|
"one_pair": {
|
||||||
|
s: "-A INPUT",
|
||||||
|
instruction: iptablesInstruction{
|
||||||
|
table: "filter",
|
||||||
|
chain: "INPUT",
|
||||||
|
append: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"instruction_A": {
|
||||||
|
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
|
||||||
|
instruction: iptablesInstruction{
|
||||||
|
table: "filter",
|
||||||
|
chain: "INPUT",
|
||||||
|
append: true,
|
||||||
|
inputInterface: "tun0",
|
||||||
|
protocol: "tcp",
|
||||||
|
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||||
|
destination: netip.MustParsePrefix("5.6.7.8/32"),
|
||||||
|
destinationPort: 10000,
|
||||||
|
target: "ACCEPT",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"nat_redirection": {
|
||||||
|
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||||
|
instruction: iptablesInstruction{
|
||||||
|
table: "nat",
|
||||||
|
chain: "PREROUTING",
|
||||||
|
append: false,
|
||||||
|
inputInterface: "tun0",
|
||||||
|
protocol: "tcp",
|
||||||
|
destinationPort: 43716,
|
||||||
|
target: "REDIRECT",
|
||||||
|
toPorts: []uint16{5678},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rule, err := parseIptablesInstruction(testCase.s)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.instruction, rule)
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_parseIPPrefix(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
value string
|
||||||
|
prefix netip.Prefix
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty": {
|
||||||
|
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
|
||||||
|
},
|
||||||
|
"invalid": {
|
||||||
|
value: "invalid",
|
||||||
|
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
|
||||||
|
},
|
||||||
|
"valid_ipv4_with_bits": {
|
||||||
|
value: "10.0.0.0/16",
|
||||||
|
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
|
||||||
|
},
|
||||||
|
"valid_ipv4_without_bits": {
|
||||||
|
value: "10.0.0.4",
|
||||||
|
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
|
||||||
|
},
|
||||||
|
"valid_ipv6_with_bits": {
|
||||||
|
value: "2001:db8::/32",
|
||||||
|
prefix: netip.PrefixFrom(
|
||||||
|
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||||
|
32),
|
||||||
|
},
|
||||||
|
"valid_ipv6_without_bits": {
|
||||||
|
value: "2001:db8::",
|
||||||
|
prefix: netip.PrefixFrom(
|
||||||
|
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||||
|
128),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
prefix, err := parseIPPrefix(testCase.value)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.prefix, prefix)
|
||||||
|
if testCase.errMessage != "" {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
|
|
||||||
|
|
||||||
// Package firewall is a generated GoMock package.
|
|
||||||
package firewall
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
command "github.com/qdm12/golibs/command"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockRunner is a mock of Runner interface.
|
|
||||||
type MockRunner struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockRunnerMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
|
||||||
type MockRunnerMockRecorder struct {
|
|
||||||
mock *MockRunner
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockRunner creates a new mock instance.
|
|
||||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
|
||||||
mock := &MockRunner{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|
||||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run mocks base method.
|
|
||||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "Run", arg0)
|
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run indicates an expected call of Run.
|
|
||||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
|
||||||
}
|
|
||||||
@@ -11,8 +11,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
|
|
||||||
|
|
||||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||||
return newCmdMatcher(path,
|
return newCmdMatcher(path,
|
||||||
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
@@ -58,9 +59,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
|||||||
|
|
||||||
servers = make([]models.Server, 0, len(hostToIPs))
|
servers = make([]models.Server, 0, len(hostToIPs))
|
||||||
for _, serverData := range data.Servers {
|
for _, serverData := range data.Servers {
|
||||||
|
city, region := parseCity(serverData.City)
|
||||||
server := models.Server{
|
server := models.Server{
|
||||||
Country: serverData.Country,
|
Country: serverData.Country,
|
||||||
City: serverData.City,
|
City: city,
|
||||||
|
Region: region,
|
||||||
ISP: serverData.ISP,
|
ISP: serverData.ISP,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,3 +99,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
|||||||
|
|
||||||
return servers, nil
|
return servers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseCity(city string) (parsedCity, region string) {
|
||||||
|
commaIndex := strings.Index(city, ", ")
|
||||||
|
if commaIndex == -1 {
|
||||||
|
return city, ""
|
||||||
|
}
|
||||||
|
return city[:commaIndex], city[commaIndex+2:]
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check out the JSON data from https://api.nordvpn.com/v2/servers?limit=10
|
// Check out the JSON data from https://api.nordvpn.com/v2/servers?limit=10
|
||||||
@@ -92,6 +93,9 @@ func (s serversData) idToData() (
|
|||||||
) {
|
) {
|
||||||
groups = make(map[uint32]groupData, len(s.Groups))
|
groups = make(map[uint32]groupData, len(s.Groups))
|
||||||
for _, group := range s.Groups {
|
for _, group := range s.Groups {
|
||||||
|
if group.Type.Identifier == "regions" { //nolint:goconst
|
||||||
|
group.Title = strings.ReplaceAll(group.Title, ",", "")
|
||||||
|
}
|
||||||
groups[group.ID] = group
|
groups[group.ID] = group
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
|
|||||||
|
|
||||||
server := models.Server{
|
server := models.Server{
|
||||||
Country: location.Country.Name,
|
Country: location.Country.Name,
|
||||||
Region: jsonServer.region(groups),
|
Region: region,
|
||||||
City: location.Country.City.Name,
|
City: location.Country.City.Name,
|
||||||
Categories: jsonServer.categories(groups),
|
Categories: jsonServer.categories(groups),
|
||||||
Hostname: jsonServer.Hostname,
|
Hostname: jsonServer.Hostname,
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
serverName := objects.ServerName
|
serverName := objects.ServerName
|
||||||
|
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||||
logger := objects.Logger
|
logger := objects.Logger
|
||||||
|
|
||||||
if !objects.CanPortForward {
|
if !objects.CanPortForward {
|
||||||
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
|||||||
|
|
||||||
if !dataFound || expired {
|
if !dataFound || expired {
|
||||||
client := objects.Client
|
client := objects.Client
|
||||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
|
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
|
||||||
p.portForwardPath, objects.Username, objects.Password)
|
p.portForwardPath, objects.Username, objects.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
||||||
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
|||||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||||
|
|
||||||
// First time binding
|
// First time binding
|
||||||
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
|
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
|
||||||
return nil, fmt.Errorf("binding port: %w", err)
|
return nil, fmt.Errorf("binding port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
|||||||
panic("gateway is not set")
|
panic("gateway is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||||
|
|
||||||
privateIPClient, err := newHTTPClient(objects.ServerName)
|
privateIPClient, err := newHTTPClient(objects.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating custom HTTP client: %w", err)
|
return fmt.Errorf("creating custom HTTP client: %w", err)
|
||||||
@@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
|||||||
}
|
}
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-keepAliveTimer.C:
|
case <-keepAliveTimer.C:
|
||||||
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
|
err = bindPort(ctx, privateIPClient, apiIP, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("binding port: %w", err)
|
return fmt.Errorf("binding port: %w", err)
|
||||||
}
|
}
|
||||||
@@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
|
||||||
|
if gateway.Is6() {
|
||||||
|
panic("IPv6 gateway not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayBytes := gateway.As4()
|
||||||
|
gatewayBytes[2] = 128
|
||||||
|
gatewayBytes[3] = 1
|
||||||
|
return netip.AddrFrom4(gatewayBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||||
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
||||||
data.Token, err = fetchToken(ctx, client, username, password)
|
data.Token, err = fetchToken(ctx, client, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("fetching token: %w", err)
|
return data, fmt.Errorf("fetching token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("fetching port forwarding data: %w", err)
|
return data, fmt.Errorf("fetching port forwarding data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -286,7 +299,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|||||||
return result.Token, nil
|
return result.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
|
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
|
||||||
port uint16, signature string, expiration time.Time, err error) {
|
port uint16, signature string, expiration time.Time, err error) {
|
||||||
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
|
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
|
||||||
|
|
||||||
@@ -294,7 +307,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti
|
|||||||
queryParams.Add("token", token)
|
queryParams.Add("token", token)
|
||||||
url := url.URL{
|
url := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
Host: net.JoinHostPort(apiIP.String(), "19999"),
|
||||||
Path: "/getSignature",
|
Path: "/getSignature",
|
||||||
RawQuery: queryParams.Encode(),
|
RawQuery: queryParams.Encode(),
|
||||||
}
|
}
|
||||||
@@ -340,7 +353,7 @@ var (
|
|||||||
ErrBadResponse = errors.New("bad response received")
|
ErrBadResponse = errors.New("bad response received")
|
||||||
)
|
)
|
||||||
|
|
||||||
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
|
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
|
||||||
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("serializing payload: %w", err)
|
return fmt.Errorf("serializing payload: %w", err)
|
||||||
@@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
|
|||||||
queryParams.Add("signature", data.Signature)
|
queryParams.Add("signature", data.Signature)
|
||||||
bindPortURL := url.URL{
|
bindPortURL := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
|
||||||
Path: "/bindPort",
|
Path: "/bindPort",
|
||||||
RawQuery: queryParams.Encode(),
|
RawQuery: queryParams.Encode(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,17 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||||
|
"github.com/qdm12/gluetun/internal/server/middlewares/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||||
|
authSettings auth.Settings,
|
||||||
buildInfo models.BuildInformation,
|
buildInfo models.BuildInformation,
|
||||||
vpnLooper VPNLooper,
|
vpnLooper VPNLooper,
|
||||||
pfGetter PortForwardedGetter,
|
pfGetter PortForwardedGetter,
|
||||||
@@ -17,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
|||||||
publicIPLooper PublicIPLoop,
|
publicIPLooper PublicIPLoop,
|
||||||
storage Storage,
|
storage Storage,
|
||||||
ipv6Supported bool,
|
ipv6Supported bool,
|
||||||
) http.Handler {
|
) (httpHandler http.Handler, err error) {
|
||||||
handler := &handler{}
|
handler := &handler{}
|
||||||
|
|
||||||
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
||||||
@@ -29,16 +33,25 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
|||||||
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
||||||
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
||||||
|
|
||||||
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
authMiddleware, err := auth.New(authSettings, logger)
|
||||||
handler.setLogEnabled = handlerWithLog.setEnabled
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating auth middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return handlerWithLog
|
middlewares := []func(http.Handler) http.Handler{
|
||||||
|
authMiddleware,
|
||||||
|
log.New(logger, logging),
|
||||||
|
}
|
||||||
|
httpHandler = handler
|
||||||
|
for _, middleware := range middlewares {
|
||||||
|
httpHandler = middleware(httpHandler)
|
||||||
|
}
|
||||||
|
return httpHandler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type handler struct {
|
type handler struct {
|
||||||
v0 http.Handler
|
v0 http.Handler
|
||||||
v1 http.Handler
|
v1 http.Handler
|
||||||
setLogEnabled func(enabled bool)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
|
Debugf(format string, args ...any)
|
||||||
infoer
|
infoer
|
||||||
warner
|
warner
|
||||||
|
Warnf(format string, args ...any)
|
||||||
errorer
|
errorer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
36
internal/server/middlewares/auth/apikey.go
Normal file
36
internal/server/middlewares/auth/apikey.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyMethod struct {
|
||||||
|
apiKeyDigest [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
|
||||||
|
return &apiKeyMethod{
|
||||||
|
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal returns true if another auth checker is equal.
|
||||||
|
// This is used to deduplicate checkers for a particular route.
|
||||||
|
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
|
||||||
|
otherTokenMethod, ok := other.(*apiKeyMethod)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
|
||||||
|
xAPIKey := request.Header.Get("X-API-Key")
|
||||||
|
if xAPIKey == "" {
|
||||||
|
xAPIKey = request.URL.Query().Get("api_key")
|
||||||
|
}
|
||||||
|
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
|
||||||
|
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
|
||||||
|
}
|
||||||
37
internal/server/middlewares/auth/basic.go
Normal file
37
internal/server/middlewares/auth/basic.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type basicAuthMethod struct {
|
||||||
|
authDigest [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBasicAuthMethod(username, password string) *basicAuthMethod {
|
||||||
|
return &basicAuthMethod{
|
||||||
|
authDigest: sha256.Sum256([]byte(username + password)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal returns true if another auth checker is equal.
|
||||||
|
// This is used to deduplicate checkers for a particular route.
|
||||||
|
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
|
||||||
|
otherBasicMethod, ok := other.(*basicAuthMethod)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.authDigest == otherBasicMethod.authDigest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
|
||||||
|
username, password, ok := request.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestAuthDigest := sha256.Sum256([]byte(username + password))
|
||||||
|
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
|
||||||
|
}
|
||||||
35
internal/server/middlewares/auth/configfile.go
Normal file
35
internal/server/middlewares/auth/configfile.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pelletier/go-toml/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read reads the toml file specified by the filepath given.
|
||||||
|
// If the file does not exist, it returns empty settings and no error.
|
||||||
|
func Read(filepath string) (settings Settings, err error) {
|
||||||
|
file, err := os.Open(filepath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return Settings{}, nil
|
||||||
|
}
|
||||||
|
return settings, fmt.Errorf("opening file: %w", err)
|
||||||
|
}
|
||||||
|
decoder := toml.NewDecoder(file)
|
||||||
|
decoder.DisallowUnknownFields()
|
||||||
|
err = decoder.Decode(&settings)
|
||||||
|
if err == nil {
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
strictErr := new(toml.StrictMissingError)
|
||||||
|
ok := errors.As(err, &strictErr)
|
||||||
|
if !ok {
|
||||||
|
return settings, fmt.Errorf("toml decoding file: %w", err)
|
||||||
|
}
|
||||||
|
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
|
||||||
|
strictErr, strictErr.String())
|
||||||
|
}
|
||||||
80
internal/server/middlewares/auth/configfile_test.go
Normal file
80
internal/server/middlewares/auth/configfile_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read reads the toml file specified by the filepath given.
|
||||||
|
func Test_Read(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
fileContent string
|
||||||
|
settings Settings
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty_file": {},
|
||||||
|
"malformed_toml": {
|
||||||
|
fileContent: "this is not a toml file",
|
||||||
|
errMessage: `toml decoding file: toml: expected character =`,
|
||||||
|
},
|
||||||
|
"unknown_field": {
|
||||||
|
fileContent: `unknown = "what is this"`,
|
||||||
|
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
|
||||||
|
1| unknown = "what is this"
|
||||||
|
| ~~~~~~~ missing field`,
|
||||||
|
},
|
||||||
|
"filled_settings": {
|
||||||
|
fileContent: `[[roles]]
|
||||||
|
name = "public"
|
||||||
|
auth = "none"
|
||||||
|
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
|
||||||
|
|
||||||
|
[[roles]]
|
||||||
|
name = "client"
|
||||||
|
auth = "apikey"
|
||||||
|
apikey = "xyz"
|
||||||
|
routes = ["GET /v1/vpn/status"]
|
||||||
|
`,
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{{
|
||||||
|
Name: "public",
|
||||||
|
Auth: AuthNone,
|
||||||
|
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
|
||||||
|
}, {
|
||||||
|
Name: "client",
|
||||||
|
Auth: AuthAPIKey,
|
||||||
|
APIKey: "xyz",
|
||||||
|
Routes: []string{"GET /v1/vpn/status"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
filepath := tempDir + "/config.toml"
|
||||||
|
const permissions fs.FileMode = 0600
|
||||||
|
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
settings, err := Read(filepath)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.settings, settings)
|
||||||
|
if testCase.errMessage != "" {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
22
internal/server/middlewares/auth/format.go
Normal file
22
internal/server/middlewares/auth/format.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
func andStrings(strings []string) (result string) {
|
||||||
|
return joinStrings(strings, "and")
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinStrings(strings []string, lastJoin string) (result string) {
|
||||||
|
if len(strings) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings[0]
|
||||||
|
for i := 1; i < len(strings); i++ {
|
||||||
|
if i < len(strings)-1 {
|
||||||
|
result += ", " + strings[i]
|
||||||
|
} else {
|
||||||
|
result += " " + lastJoin + " " + strings[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
6
internal/server/middlewares/auth/interfaces.go
Normal file
6
internal/server/middlewares/auth/interfaces.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
type DebugLogger interface {
|
||||||
|
Debugf(format string, args ...any)
|
||||||
|
Warnf(format string, args ...any)
|
||||||
|
}
|
||||||
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type authorizationChecker interface {
|
||||||
|
equal(other authorizationChecker) bool
|
||||||
|
isAuthorized(headers http.Header, request *http.Request) bool
|
||||||
|
}
|
||||||
47
internal/server/middlewares/auth/lookup.go
Normal file
47
internal/server/middlewares/auth/lookup.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type internalRole struct {
|
||||||
|
name string
|
||||||
|
checker authorizationChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
|
||||||
|
routeToRoles = make(map[string][]internalRole)
|
||||||
|
for _, role := range settings.Roles {
|
||||||
|
var checker authorizationChecker
|
||||||
|
switch role.Auth {
|
||||||
|
case AuthNone:
|
||||||
|
checker = newNoneMethod()
|
||||||
|
case AuthAPIKey:
|
||||||
|
checker = newAPIKeyMethod(role.APIKey)
|
||||||
|
case AuthBasic:
|
||||||
|
checker = newBasicAuthMethod(role.Username, role.Password)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
iRole := internalRole{
|
||||||
|
name: role.Name,
|
||||||
|
checker: checker,
|
||||||
|
}
|
||||||
|
for _, route := range role.Routes {
|
||||||
|
checkerExists := false
|
||||||
|
for _, role := range routeToRoles[route] {
|
||||||
|
if role.checker.equal(iRole.checker) {
|
||||||
|
checkerExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if checkerExists {
|
||||||
|
// even if the role name is different, if the checker is the same, skip it.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
routeToRoles[route] = append(routeToRoles[route], iRole)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return routeToRoles, nil
|
||||||
|
}
|
||||||
60
internal/server/middlewares/auth/lookup_test.go
Normal file
60
internal/server/middlewares/auth/lookup_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read reads the toml file specified by the filepath given.
|
||||||
|
func Test_settingsToLookupMap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
settings Settings
|
||||||
|
routeToRoles map[string][]internalRole
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty_settings": {
|
||||||
|
routeToRoles: map[string][]internalRole{},
|
||||||
|
},
|
||||||
|
"auth_method_not_supported": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{{Name: "a", Auth: "bad"}},
|
||||||
|
},
|
||||||
|
errWrapped: ErrMethodNotSupported,
|
||||||
|
errMessage: "authentication method not supported: bad",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
|
||||||
|
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routeToRoles: map[string][]internalRole{
|
||||||
|
"GET /path": {
|
||||||
|
{name: "a", checker: newNoneMethod()}, // deduplicated method
|
||||||
|
},
|
||||||
|
"PUT /path": {
|
||||||
|
{name: "b", checker: newNoneMethod()},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
routeToRoles, err := settingsToLookupMap(testCase.settings)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.routeToRoles, routeToRoles)
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
111
internal/server/middlewares/auth/middleware.go
Normal file
111
internal/server/middlewares/auth/middleware.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(settings Settings, debugLogger DebugLogger) (
|
||||||
|
middleware func(http.Handler) http.Handler,
|
||||||
|
err error) {
|
||||||
|
routeToRoles, err := settingsToLookupMap(settings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:goconst
|
||||||
|
return func(handler http.Handler) http.Handler {
|
||||||
|
return &authHandler{
|
||||||
|
childHandler: handler,
|
||||||
|
routeToRoles: routeToRoles,
|
||||||
|
unprotectedRoutes: map[string]struct{}{
|
||||||
|
http.MethodGet + " /openvpn/actions/restart": {},
|
||||||
|
http.MethodGet + " /unbound/actions/restart": {},
|
||||||
|
http.MethodGet + " /updater/restart": {},
|
||||||
|
http.MethodGet + " /v1/version": {},
|
||||||
|
http.MethodGet + " /v1/vpn/status": {},
|
||||||
|
http.MethodPut + " /v1/vpn/status": {},
|
||||||
|
// GET /v1/vpn/settings is protected by default
|
||||||
|
// PUT /v1/vpn/settings is protected by default
|
||||||
|
http.MethodGet + " /v1/openvpn/status": {},
|
||||||
|
http.MethodPut + " /v1/openvpn/status": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||||
|
// GET /v1/openvpn/settings is protected by default
|
||||||
|
http.MethodGet + " /v1/dns/status": {},
|
||||||
|
http.MethodPut + " /v1/dns/status": {},
|
||||||
|
http.MethodGet + " /v1/updater/status": {},
|
||||||
|
http.MethodPut + " /v1/updater/status": {},
|
||||||
|
http.MethodGet + " /v1/publicip/ip": {},
|
||||||
|
},
|
||||||
|
logger: debugLogger,
|
||||||
|
}
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authHandler struct {
|
||||||
|
childHandler http.Handler
|
||||||
|
routeToRoles map[string][]internalRole
|
||||||
|
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
|
||||||
|
logger DebugLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
route := request.Method + " " + request.URL.Path
|
||||||
|
roles := h.routeToRoles[route]
|
||||||
|
if len(roles) == 0 {
|
||||||
|
h.logger.Debugf("no authentication role defined for route %s", route)
|
||||||
|
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responseHeader := make(http.Header, 0)
|
||||||
|
for _, role := range roles {
|
||||||
|
if !role.checker.isAuthorized(responseHeader, request) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
|
||||||
|
|
||||||
|
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
|
||||||
|
h.childHandler.ServeHTTP(writer, request)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush out response headers if all roles failed to authenticate
|
||||||
|
for headerKey, headerValues := range responseHeader {
|
||||||
|
for _, headerValue := range headerValues {
|
||||||
|
writer.Header().Add(headerKey, headerValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allRoleNames := make([]string, len(roles))
|
||||||
|
for i, role := range roles {
|
||||||
|
allRoleNames[i] = role.name
|
||||||
|
}
|
||||||
|
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
|
||||||
|
route, andStrings(allRoleNames))
|
||||||
|
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
|
||||||
|
// TODO v3.41.0 remove
|
||||||
|
if role.name != "public" {
|
||||||
|
// custom role name, allow none authentication to be specified
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, isNoneChecker := role.checker.(*noneMethod)
|
||||||
|
if !isNoneChecker {
|
||||||
|
// not the none authentication method
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
|
||||||
|
if !isUnprotectedByDefault {
|
||||||
|
// route is not unprotected by default, so this is a user decision
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Warnf("route %s is unprotected by default, "+
|
||||||
|
"please set up authentication following the documentation at "+
|
||||||
|
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||||
|
"since this will become no longer publicly accessible after release v3.40.",
|
||||||
|
route)
|
||||||
|
}
|
||||||
124
internal/server/middlewares/auth/middleware_test.go
Normal file
124
internal/server/middlewares/auth/middleware_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_authHandler_ServeHTTP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
settings Settings
|
||||||
|
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
|
||||||
|
requestMethod string
|
||||||
|
requestPath string
|
||||||
|
statusCode int
|
||||||
|
responseBody string
|
||||||
|
}{
|
||||||
|
"route_has_no_role": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
logger := NewMockDebugLogger(ctrl)
|
||||||
|
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/b",
|
||||||
|
statusCode: http.StatusUnauthorized,
|
||||||
|
responseBody: "Unauthorized\n",
|
||||||
|
},
|
||||||
|
"authorized_unprotected_by_default": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
logger := NewMockDebugLogger(ctrl)
|
||||||
|
logger.EXPECT().Warnf("route %s is unprotected by default, "+
|
||||||
|
"please set up authentication following the documentation at "+
|
||||||
|
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||||
|
"since this will become no longer publicly accessible after release v3.40.",
|
||||||
|
"GET /v1/vpn/status")
|
||||||
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||||
|
"GET /v1/vpn/status", "public")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/v1/vpn/status",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
"authorized_none": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
logger := NewMockDebugLogger(ctrl)
|
||||||
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||||
|
"GET /a", "role1")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/a",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
var debugLogger DebugLogger
|
||||||
|
if testCase.makeLogger != nil {
|
||||||
|
debugLogger = testCase.makeLogger(ctrl)
|
||||||
|
}
|
||||||
|
middleware, err := New(testCase.settings, debugLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
handler := middleware(childHandler)
|
||||||
|
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
|
||||||
|
client := server.Client()
|
||||||
|
|
||||||
|
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
request, err := http.NewRequestWithContext(context.Background(),
|
||||||
|
testCase.requestMethod, requestURL, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
response, err := client.Do(request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = response.Body.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.statusCode, response.StatusCode)
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testCase.responseBody, string(body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger
|
||||||
68
internal/server/middlewares/auth/mocks_test.go
Normal file
68
internal/server/middlewares/auth/mocks_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
|
||||||
|
|
||||||
|
// Package auth is a generated GoMock package.
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockDebugLogger is a mock of DebugLogger interface.
|
||||||
|
type MockDebugLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockDebugLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
|
||||||
|
type MockDebugLoggerMockRecorder struct {
|
||||||
|
mock *MockDebugLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockDebugLogger creates a new mock instance.
|
||||||
|
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
mock := &MockDebugLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockDebugLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf mocks base method.
|
||||||
|
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Debugf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf indicates an expected call of Debugf.
|
||||||
|
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf mocks base method.
|
||||||
|
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Warnf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf indicates an expected call of Warnf.
|
||||||
|
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
|
||||||
|
}
|
||||||
20
internal/server/middlewares/auth/none.go
Normal file
20
internal/server/middlewares/auth/none.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type noneMethod struct{}
|
||||||
|
|
||||||
|
func newNoneMethod() *noneMethod {
|
||||||
|
return &noneMethod{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal returns true if another auth checker is equal.
|
||||||
|
// This is used to deduplicate checkers for a particular route.
|
||||||
|
func (n *noneMethod) equal(other authorizationChecker) bool {
|
||||||
|
_, ok := other.(*noneMethod)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
131
internal/server/middlewares/auth/settings.go
Normal file
131
internal/server/middlewares/auth/settings.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/qdm12/gosettings"
|
||||||
|
"github.com/qdm12/gosettings/validate"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Settings struct {
|
||||||
|
// Roles is a list of roles with their associated authentication
|
||||||
|
// and routes.
|
||||||
|
Roles []Role
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Settings) SetDefaults() {
|
||||||
|
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
|
||||||
|
Name: "public",
|
||||||
|
Auth: "none",
|
||||||
|
Routes: []string{
|
||||||
|
http.MethodGet + " /openvpn/actions/restart",
|
||||||
|
http.MethodGet + " /unbound/actions/restart",
|
||||||
|
http.MethodGet + " /updater/restart",
|
||||||
|
http.MethodGet + " /v1/version",
|
||||||
|
http.MethodGet + " /v1/vpn/status",
|
||||||
|
http.MethodPut + " /v1/vpn/status",
|
||||||
|
http.MethodGet + " /v1/openvpn/status",
|
||||||
|
http.MethodPut + " /v1/openvpn/status",
|
||||||
|
http.MethodGet + " /v1/openvpn/portforwarded",
|
||||||
|
http.MethodGet + " /v1/dns/status",
|
||||||
|
http.MethodPut + " /v1/dns/status",
|
||||||
|
http.MethodGet + " /v1/updater/status",
|
||||||
|
http.MethodPut + " /v1/updater/status",
|
||||||
|
http.MethodGet + " /v1/publicip/ip",
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Settings) Validate() (err error) {
|
||||||
|
for i, role := range s.Roles {
|
||||||
|
err = role.validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("role %s (%d of %d): %w",
|
||||||
|
role.Name, i+1, len(s.Roles), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthNone = "none"
|
||||||
|
AuthAPIKey = "apikey"
|
||||||
|
AuthBasic = "basic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Role contains the role name, authentication method name and
|
||||||
|
// routes that the role can access.
|
||||||
|
type Role struct {
|
||||||
|
// Name is the role name and is only used for documentation
|
||||||
|
// and in the authentication middleware debug logs.
|
||||||
|
Name string
|
||||||
|
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
|
||||||
|
Auth string
|
||||||
|
// APIKey is the API key to use when using the 'apikey' authentication.
|
||||||
|
APIKey string
|
||||||
|
// Username for HTTP Basic authentication method.
|
||||||
|
Username string
|
||||||
|
// Password for HTTP Basic authentication method.
|
||||||
|
Password string
|
||||||
|
// Routes is a list of routes that the role can access in the format
|
||||||
|
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
|
||||||
|
Routes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMethodNotSupported = errors.New("authentication method not supported")
|
||||||
|
ErrAPIKeyEmpty = errors.New("api key is empty")
|
||||||
|
ErrBasicUsernameEmpty = errors.New("username is empty")
|
||||||
|
ErrBasicPasswordEmpty = errors.New("password is empty")
|
||||||
|
ErrRouteNotSupported = errors.New("route not supported by the control server")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r Role) validate() (err error) {
|
||||||
|
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case r.Auth == AuthAPIKey && r.APIKey == "":
|
||||||
|
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
|
||||||
|
case r.Auth == AuthBasic && r.Username == "":
|
||||||
|
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
|
||||||
|
case r.Auth == AuthBasic && r.Password == "":
|
||||||
|
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, route := range r.Routes {
|
||||||
|
_, ok := validRoutes[route]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("route %d of %d: %w: %s",
|
||||||
|
i+1, len(r.Routes), ErrRouteNotSupported, route)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING: do not mutate programmatically.
|
||||||
|
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
||||||
|
http.MethodGet + " /openvpn/actions/restart": {},
|
||||||
|
http.MethodGet + " /unbound/actions/restart": {},
|
||||||
|
http.MethodGet + " /updater/restart": {},
|
||||||
|
http.MethodGet + " /v1/version": {},
|
||||||
|
http.MethodGet + " /v1/vpn/status": {},
|
||||||
|
http.MethodPut + " /v1/vpn/status": {},
|
||||||
|
http.MethodGet + " /v1/vpn/settings": {},
|
||||||
|
http.MethodPut + " /v1/vpn/settings": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/status": {},
|
||||||
|
http.MethodPut + " /v1/openvpn/status": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/settings": {},
|
||||||
|
http.MethodGet + " /v1/dns/status": {},
|
||||||
|
http.MethodPut + " /v1/dns/status": {},
|
||||||
|
http.MethodGet + " /v1/updater/status": {},
|
||||||
|
http.MethodPut + " /v1/updater/status": {},
|
||||||
|
http.MethodGet + " /v1/publicip/ip": {},
|
||||||
|
}
|
||||||
5
internal/server/middlewares/log/interfaces.go
Normal file
5
internal/server/middlewares/log/interfaces.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package log
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Info(message string)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -7,18 +7,21 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
|
func New(logger Logger, enabled bool) (
|
||||||
return &logMiddleware{
|
middleware func(http.Handler) http.Handler) {
|
||||||
childHandler: childHandler,
|
return func(handler http.Handler) http.Handler {
|
||||||
logger: logger,
|
return &logMiddleware{
|
||||||
timeNow: time.Now,
|
childHandler: handler,
|
||||||
enabled: enabled,
|
logger: logger,
|
||||||
|
timeNow: time.Now,
|
||||||
|
enabled: enabled,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type logMiddleware struct {
|
type logMiddleware struct {
|
||||||
childHandler http.Handler
|
childHandler http.Handler
|
||||||
logger infoer
|
logger Logger
|
||||||
timeNow func() time.Time
|
timeNow func() time.Time
|
||||||
enabled bool
|
enabled bool
|
||||||
enabledMu sync.RWMutex
|
enabledMu sync.RWMutex
|
||||||
@@ -39,7 +42,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
r.RemoteAddr + " in " + duration.String())
|
r.RemoteAddr + " in " + duration.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *logMiddleware) setEnabled(enabled bool) {
|
func (m *logMiddleware) SetEnabled(enabled bool) {
|
||||||
m.enabledMu.Lock()
|
m.enabledMu.Lock()
|
||||||
defer m.enabledMu.Unlock()
|
defer m.enabledMu.Unlock()
|
||||||
m.enabled = enabled
|
m.enabled = enabled
|
||||||
@@ -6,17 +6,31 @@ import (
|
|||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/httpserver"
|
"github.com/qdm12/gluetun/internal/httpserver"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
||||||
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||||
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
||||||
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
||||||
ipv6Supported bool) (
|
ipv6Supported bool) (
|
||||||
server *httpserver.Server, err error) {
|
server *httpserver.Server, err error) {
|
||||||
handler := newHandler(ctx, logger, logEnabled, buildInfo,
|
authSettings, err := auth.Read(authConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading auth settings: %w", err)
|
||||||
|
}
|
||||||
|
authSettings.SetDefaults()
|
||||||
|
err = authSettings.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validating auth settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
|
||||||
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
||||||
storage, ipv6Supported)
|
storage, ipv6Supported)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating handler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
httpServerSettings := httpserver.Settings{
|
httpServerSettings := httpserver.Settings{
|
||||||
Address: address,
|
Address: address,
|
||||||
|
|||||||
@@ -128,6 +128,31 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
|
|||||||
messageParts = append(messageParts, "premium tier only")
|
messageParts = append(messageParts, "premium tier only")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *selection.StreamOnly {
|
||||||
|
messageParts = append(messageParts, "stream only")
|
||||||
|
}
|
||||||
|
|
||||||
|
if *selection.MultiHopOnly {
|
||||||
|
messageParts = append(messageParts, "multihop only")
|
||||||
|
}
|
||||||
|
|
||||||
|
if *selection.PortForwardOnly {
|
||||||
|
messageParts = append(messageParts, "port forwarding only")
|
||||||
|
}
|
||||||
|
|
||||||
|
if *selection.SecureCoreOnly {
|
||||||
|
messageParts = append(messageParts, "secure core only")
|
||||||
|
}
|
||||||
|
|
||||||
|
if *selection.TorOnly {
|
||||||
|
messageParts = append(messageParts, "tor only")
|
||||||
|
}
|
||||||
|
|
||||||
|
if selection.TargetIP.IsValid() {
|
||||||
|
messageParts = append(messageParts,
|
||||||
|
"target ip address "+selection.TargetIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
message := "for " + strings.Join(messageParts, "; ")
|
message := "for " + strings.Join(messageParts, "; ")
|
||||||
|
|
||||||
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
|
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user