Compare commits

..

1 Commits

Author SHA1 Message Date
Quentin McGaw
1a6e8d74d6 wip 2024-08-01 07:51:35 +00:00
77 changed files with 1041 additions and 12783 deletions

View File

@@ -37,12 +37,12 @@
"go.useLanguageServer": true,
"[go]": {
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
"source.organizeImports": true
}
},
"[go.mod]": {
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
"source.organizeImports": true
}
},
"gopls": {

View File

@@ -1,10 +1,10 @@
blank_issues_enabled: false
contact_links:
- name: Report a Wiki issue
url: https://github.com/qdm12/gluetun-wiki/issues/new/choose
url: https://github.com/qdm12/gluetun-wiki/issues/new
about: Please create an issue on the gluetun-wiki repository.
- name: Configuration help?
url: https://github.com/qdm12/gluetun/discussions/new/choose
url: https://github.com/qdm12/gluetun/discussions/new
about: Please create a Github discussion.
- name: Unraid template issue
url: https://github.com/qdm12/gluetun/discussions/550

8
.github/labels.yml vendored
View File

@@ -3,9 +3,6 @@
- name: "Status: 🔴 Blocked"
color: "f7d692"
description: "Blocked by another issue or pull request"
- name: "Status: 📌 Before next release"
color: "f7d692"
description: "Has to be done before the next release"
- name: "Status: 🔒 After next release"
color: "f7d692"
description: "Will be done after the next release"
@@ -39,8 +36,6 @@
# VPN providers
- name: "☁️ AirVPN"
color: "cfe8d4"
- name: "☁️ Custom"
color: "cfe8d4"
- name: "☁️ Cyberghost"
color: "cfe8d4"
- name: "☁️ HideMyAss"
@@ -96,9 +91,6 @@
- name: "Category: Maintenance ⛓️"
description: "Anything related to code or other maintenance"
color: "ffc7ea"
- name: "Category: Logs 📚"
description: "Something to change in logs"
color: "ffc7ea"
- name: "Category: Good idea 🎯"
description: "This is a good idea, judged by the maintainers"
color: "ffc7ea"

View File

@@ -120,3 +120,9 @@ linters:
- wastedassign
- whitespace
- zerologlint
run:
skip-dirs:
- .devcontainer
- .github
- doc

View File

@@ -197,7 +197,6 @@ ENV VPN_SERVICE_PROVIDER=pia \
# Control server
HTTP_CONTROL_SERVER_LOG=on \
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
# Server data updater
UPDATER_PERIOD=0 \
UPDATER_MIN_RATIO=0.8 \

View File

@@ -60,8 +60,8 @@ Lightweight swiss-knife-like VPN client to multiple VPN service providers
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **WeVPN**, **Windscribe** servers
- Supports OpenVPN for all providers listed
- Supports Wireguard both kernelspace and userspace
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited**, **VyprVPN** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **Surfshark** and **Windscribe**
- For **ProtonVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
- More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134)
- DNS over TLS baked in with service provider(s) of your choice
@@ -73,7 +73,7 @@ Lightweight swiss-knife-like VPN client to multiple VPN service providers
- [Connect other containers to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-container-to-gluetun.md)
- [Connect LAN devices to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-lan-device-to-gluetun.md)
- Compatible with amd64, i686 (32 bit), **ARM** 64 bit, ARM 32 bit v6 and v7, and even ppc64le 🎆
- Custom VPN server side port forwarding for [Perfect Privacy](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/perfect-privacy.md#vpn-server-port-forwarding), [Private Internet Access](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/private-internet-access.md#vpn-server-port-forwarding) and [ProtonVPN](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/protonvpn.md#vpn-server-port-forwarding)
- [Custom VPN server side port forwarding for Private Internet Access](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/private-internet-access.md#vpn-server-port-forwarding)
- Possibility of split horizon DNS by selecting multiple DNS over TLS providers
- Unbound subprogram drops root privileges once launched
- Can work as a Kubernetes sidecar container, thanks @rorph
@@ -84,7 +84,7 @@ Lightweight swiss-knife-like VPN client to multiple VPN service providers
Go to the [Wiki](https://github.com/qdm12/gluetun-wiki)!
[🐛 Found a bug in the Wiki?!](https://github.com/qdm12/gluetun-wiki/issues/new/choose)
[🐛 Found a bug in the Wiki?!](https://github.com/qdm12/gluetun-wiki/issues/new)
Here's a docker-compose.yml for the laziest:

View File

@@ -161,14 +161,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return cli.Update(ctx, args[2:], logger)
case "format-servers":
return cli.FormatServers(args[2:])
case "genkey":
return cli.GenKey(args[2:])
default:
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
}
}
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
if err != nil {
return err
}
@@ -178,8 +176,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Emails: []string{"quentin.mcgaw@gmail.com"},
Version: buildInfo.Version,
Commit: buildInfo.Commit,
Created: buildInfo.Created,
Announcement: "All control server routes will become private by default after the v3.41.0 release",
BuildDate: buildInfo.Created,
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
AnnounceExp: announcementExp,
// Sponsor information
PaypalUser: "qmcgaw",
@@ -476,7 +474,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.New(log.SetComponent("http server")),
allSettings.ControlServer.AuthFilePath,
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
if err != nil {
@@ -598,7 +595,6 @@ type clier interface {
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
GenKey(args []string) error
}
type Tun interface {

5
go.mod
View File

@@ -3,17 +3,16 @@ module github.com/qdm12/gluetun
go 1.22
require (
github.com/breml/rootcerts v0.2.17
github.com/breml/rootcerts v0.2.16
github.com/fatih/color v1.17.0
github.com/golang/mock v1.6.0
github.com/klauspost/compress v1.17.8
github.com/klauspost/pgzip v1.2.6
github.com/pelletier/go-toml/v2 v2.2.2
github.com/qdm12/dns v1.11.0
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
github.com/qdm12/gosettings v0.4.2
github.com/qdm12/goshutdown v0.3.0
github.com/qdm12/gosplash v0.2.0
github.com/qdm12/gosplash v0.1.0
github.com/qdm12/gotree v0.2.0
github.com/qdm12/log v0.1.0
github.com/qdm12/ss-server v0.6.0

18
go.sum
View File

@@ -4,8 +4,8 @@ github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7/go.mod h1:6zEj6s6u/g
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/breml/rootcerts v0.2.17 h1:0/M2BE2Apw0qEJCXDOkaiu7d5Sx5ObNfe1BkImJ4u1I=
github.com/breml/rootcerts v0.2.17/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/breml/rootcerts v0.2.16 h1:yN1TGvicfHx8dKz3OQRIrx/5nE/iN3XT1ibqGbd6urc=
github.com/breml/rootcerts v0.2.16/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -83,8 +83,6 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -95,12 +93,14 @@ github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8
github.com/qdm12/golibs v0.0.0-20210723175634-a75ca7fd74c2/go.mod h1:6aRbg4Z/bTbm9JfxsGXfWKHi7zsOvPfUTK1S5HuAFKg=
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6 h1:bge5AL7cjHJMPz+5IOz5yF01q/l8No6+lIEBieA8gMg=
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6/go.mod h1:6aRbg4Z/bTbm9JfxsGXfWKHi7zsOvPfUTK1S5HuAFKg=
github.com/qdm12/gosettings v0.4.1 h1:c7+14jO1Y2kFXBCUfS2+QE2NgwTKfzcdJzGEFRItCI8=
github.com/qdm12/gosettings v0.4.1/go.mod h1:uItKwGXibJp2pQ0am6MBKilpjfvYTGiH+zXHd10jFj8=
github.com/qdm12/gosettings v0.4.2 h1:Gb39NScPr7OQV+oy0o1OD7A121udITDJuUGa7ljDF58=
github.com/qdm12/gosettings v0.4.2/go.mod h1:CPrt2YC4UsURTrslmhxocVhMCW03lIrqdH2hzIf5prg=
github.com/qdm12/goshutdown v0.3.0 h1:pqBpJkdwlZlfTEx4QHtS8u8CXx6pG0fVo6S1N0MpSEM=
github.com/qdm12/goshutdown v0.3.0/go.mod h1:EqZ46No00kCTZ5qzdd3qIzY6ayhMt24QI8Mh8LVQYmM=
github.com/qdm12/gosplash v0.2.0 h1:DOxCEizbW6ZG+FgpH2oK1atT6bM8MHL9GZ2ywSS4zZY=
github.com/qdm12/gosplash v0.2.0/go.mod h1:k+1PzhO0th9cpX4q2Nneu4xTsndXqrM/x7NTIYmJ4jo=
github.com/qdm12/gosplash v0.1.0 h1:Sfl+zIjFZFP7b0iqf2l5UkmEY97XBnaKkH3FNY6Gf7g=
github.com/qdm12/gosplash v0.1.0/go.mod h1:+A3fWW4/rUeDXhY3ieBzwghKdnIPFJgD8K3qQkenJlw=
github.com/qdm12/gotree v0.2.0 h1:+58ltxkNLUyHtATFereAcOjBVfY6ETqRex8XK90Fb/c=
github.com/qdm12/gotree v0.2.0/go.mod h1:1SdFaqKZuI46U1apbXIf25pDMNnrPuYLEqMF/qL4lY4=
github.com/qdm12/log v0.1.0 h1:jYBd/xscHYpblzZAd2kjZp2YmuYHjAAfbTViJWxoPTw=
@@ -115,16 +115,10 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=

View File

@@ -1,66 +0,0 @@
package cli
import (
"crypto/rand"
"flag"
"fmt"
)
func (c *CLI) GenKey(args []string) (err error) {
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
err = flagSet.Parse(args)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
const keyLength = 128 / 8
keyBytes := make([]byte, keyLength)
_, _ = rand.Read(keyBytes)
key := base58Encode(keyBytes)
fmt.Println(key)
return nil
}
func base58Encode(data []byte) string {
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
const radix = 58
zcount := 0
for zcount < len(data) && data[zcount] == 0 {
zcount++
}
// integer simplification of ceil(log(256)/log(58))
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
size := zcount + ceilLog256Div58
output := make([]byte, size)
high := size - 1
for _, b := range data {
i := size - 1
for carry := uint32(b); i > high || carry != 0; i-- {
carry += 256 * uint32(output[i]) //nolint:gomnd
output[i] = byte(carry % radix)
carry /= radix
}
high = i
}
// Determine the additional "zero-gap" in the output buffer
additionalZeroGapEnd := zcount
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
additionalZeroGapEnd++
}
val := output[additionalZeroGapEnd-zcount:]
size = len(val)
for i := range val {
output[i] = alphabet[val[i]]
}
return string(output[:size])
}

View File

@@ -39,7 +39,6 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
providers.Ivpn,
providers.Mullvad,
providers.Nordvpn,
providers.Protonvpn,
providers.Surfshark,
providers.Windscribe,
}

View File

@@ -19,11 +19,6 @@ type ControlServer struct {
// Log can be true or false to enable logging on requests.
// It cannot be nil in the internal state.
Log *bool
// AuthFilePath is the path to the file containing the authentication
// configuration for the middleware.
// It cannot be empty in the internal state and defaults to
// /gluetun/auth/config.toml.
AuthFilePath string
}
func (c ControlServer) validate() (err error) {
@@ -49,9 +44,8 @@ func (c ControlServer) validate() (err error) {
func (c *ControlServer) copy() (copied ControlServer) {
return ControlServer{
Address: gosettings.CopyPointer(c.Address),
Log: gosettings.CopyPointer(c.Log),
AuthFilePath: c.AuthFilePath,
Address: gosettings.CopyPointer(c.Address),
Log: gosettings.CopyPointer(c.Log),
}
}
@@ -61,13 +55,11 @@ func (c *ControlServer) copy() (copied ControlServer) {
func (c *ControlServer) overrideWith(other ControlServer) {
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
}
func (c *ControlServer) setDefaults() {
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
c.Log = gosettings.DefaultPointer(c.Log, true)
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
}
func (c ControlServer) String() string {
@@ -78,7 +70,6 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
node = gotree.New("Control server settings:")
node.Appendf("Listening address: %s", *c.Address)
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
node.Appendf("Authentication file path: %s", c.AuthFilePath)
return node
}
@@ -87,10 +78,6 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
if err != nil {
return err
}
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
return nil
}

View File

@@ -191,19 +191,11 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
}
if vpnServiceProvider == providers.Custom {
switch len(settings.Names) {
case 0:
case 1:
// Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
default:
return fmt.Errorf("%w: %d names specified instead of "+
"0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names))
}
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
// Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
}
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
if err != nil {
@@ -237,8 +229,6 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
switch {
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
case *settings.StreamOnly &&
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)

View File

@@ -78,8 +78,7 @@ func Test_Settings_String(t *testing.T) {
| └── Enabled: no
├── Control server settings:
| ├── Listening address: :8000
| ── Logging: yes
| └── Authentication file path: /gluetun/auth/config.toml
| ── Logging: yes
├── OS Alpine settings:
| ├── Process UID: 1000
| └── Process GID: 1000

View File

@@ -62,7 +62,6 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
providers.Ivpn,
providers.Mullvad,
providers.Nordvpn,
providers.Protonvpn,
providers.Surfshark,
providers.Windscribe,
) {
@@ -174,15 +173,10 @@ func (w *Wireguard) overrideWith(other Wireguard) {
func (w *Wireguard) setDefaults(vpnProvider string) {
w.PrivateKey = gosettings.DefaultPointer(w.PrivateKey, "")
w.PreSharedKey = gosettings.DefaultPointer(w.PreSharedKey, "")
switch vpnProvider {
case providers.Nordvpn:
if vpnProvider == providers.Nordvpn {
defaultNordVPNAddress := netip.AddrFrom4([4]byte{10, 5, 0, 2})
defaultNordVPNPrefix := netip.PrefixFrom(defaultNordVPNAddress, defaultNordVPNAddress.BitLen())
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultNordVPNPrefix})
case providers.Protonvpn:
defaultAddress := netip.AddrFrom4([4]byte{10, 2, 0, 2})
defaultPrefix := netip.PrefixFrom(defaultAddress, defaultAddress.BitLen())
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultPrefix})
}
defaultAllowedIPs := []netip.Prefix{
netip.PrefixFrom(netip.IPv4Unspecified(), 0),

View File

@@ -39,8 +39,8 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
// Validate EndpointIP
switch vpnProvider {
case providers.Airvpn, providers.Fastestvpn, providers.Ivpn,
providers.Mullvad, providers.Nordvpn, providers.Protonvpn,
providers.Surfshark, providers.Windscribe:
providers.Mullvad, providers.Nordvpn, providers.Surfshark,
providers.Windscribe:
// endpoint IP addresses are baked in
case providers.Custom:
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
@@ -57,8 +57,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
}
// EndpointPort cannot be set
case providers.Fastestvpn, providers.Nordvpn,
providers.Protonvpn, providers.Surfshark:
case providers.Fastestvpn, providers.Surfshark, providers.Nordvpn:
if *w.EndpointPort != 0 {
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
}

View File

@@ -3,4 +3,13 @@ package vpn
const (
OpenVPN = "openvpn"
Wireguard = "wireguard"
Both = "openvpn+wireguard"
)
func IsWireguard(s string) bool {
return s == Wireguard || s == Both
}
func IsOpenVPN(s string) bool {
return s == OpenVPN || s == Both
}

View File

@@ -1,98 +0,0 @@
package firewall
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
)
// isDeleteMatchInstruction returns true if the iptables instruction
// is a delete instruction by rule matching. It returns false if the
// instruction is a delete instruction by line number, or not a delete
// instruction.
func isDeleteMatchInstruction(instruction string) bool {
fields := strings.Fields(instruction)
for i, field := range fields {
switch {
case field != "-D" && field != "--delete": //nolint:goconst
continue
case i == len(fields)-1: // malformed: missing chain name
return false
case i == len(fields)-2: // chain name is last field
return true
default:
// chain name is fields[i+1]
const base, bitLength = 10, 16
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
return err != nil // not a line number
}
}
return false
}
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
runner Runner, logger Logger) (err error) {
targetRule, err := parseIptablesInstruction(instruction)
if err != nil {
return fmt.Errorf("parsing iptables command: %w", err)
}
lineNumber, err := findLineNumber(ctx, iptablesBinary,
targetRule, runner, logger)
if err != nil {
return fmt.Errorf("finding iptables chain rule line number: %w", err)
} else if lineNumber == 0 {
logger.Debug("rule matching \"" + instruction + "\" not found")
return nil
}
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
instruction, lineNumber))
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return err
}
return nil
}
// findLineNumber finds the line number of an iptables rule.
// It returns 0 if the rule is not found.
func findLineNumber(ctx context.Context, iptablesBinary string,
instruction iptablesInstruction, runner Runner, logger Logger) (
lineNumber uint16, err error) {
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
"--line-numbers", "-n", "-v"}
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return 0, err
}
chain, err := parseChain(output)
if err != nil {
return 0, fmt.Errorf("parsing chain list: %w", err)
}
for _, rule := range chain.rules {
if instruction.equalToRule(instruction.table, chain.name, rule) {
return rule.lineNumber, nil
}
}
return 0, nil
}

View File

@@ -1,188 +0,0 @@
package firewall
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_isDeleteMatchInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
instruction string
isDeleteMatch bool
}{
"not_delete": {
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
},
"malformed_missing_chain_name": {
instruction: "-t nat -D",
},
"delete_chain_name_last_field": {
instruction: "-t nat --delete PREROUTING",
isDeleteMatch: true,
},
"delete_match": {
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
isDeleteMatch: true,
},
"delete_line_number_last_field": {
instruction: "-t nat -D PREROUTING 2",
},
"delete_line_number": {
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
})
}
}
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
"^--line-numbers$", "^-n$", "^-v$")
}
func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()
const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
"rule_not_found": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
num pkts bytes target prot opt in out source destination
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
return logger
},
},
"rule_found_delete_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
ctx := context.Background()
instruction := testCase.instruction
var runner *MockRunner
if testCase.makeRunner != nil {
runner = testCase.makeRunner(ctrl)
}
var logger *MockLogger
if testCase.makeLogger != nil {
logger = testCase.makeLogger(ctrl)
}
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -1,13 +0,0 @@
package firewall
import "github.com/qdm12/golibs/command"
type Runner interface {
Run(cmd command.ExecCmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}

View File

@@ -15,7 +15,7 @@ import (
// empty string path is returned.
func findIP6tablesSupported(ctx context.Context, runner command.Runner) (
ip6tablesPath string, err error) {
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables-legacy", "ip6tables", "ip6tables-nft")
if errors.Is(err, ErrIPTablesNotSupported) {
return "", nil
} else if err != nil {
@@ -40,14 +40,10 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
c.runner, c.logger)
}
c.logger.Debug(c.ip6Tables + " " + instruction)
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
@@ -59,7 +55,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP": //nolint:goconst
case "ACCEPT", "DROP":
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}

View File

@@ -70,14 +70,10 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}
c.logger.Debug(c.ipTables + " " + instruction)
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
@@ -147,7 +143,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp" //nolint:goconst
protocol = "tcp"
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
@@ -214,14 +210,10 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
}
err = c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
@@ -229,14 +221,10 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
}
err = c.runIP6tablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
appendOrDelete(remove), interfaceFlag, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",

View File

@@ -1,381 +0,0 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type chain struct {
name string
policy string
packets uint64
bytes uint64
rules []chainRule
}
type chainRule struct {
lineNumber uint16 // starts from 1 and cannot be zero.
packets uint64
bytes uint64
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
protocol string // "tcp", "udp" or "" for all protocols.
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
}
var (
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
)
func parseChain(iptablesOutput string) (c chain, err error) {
// Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// pkts bytes target prot opt in out source destination
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
iptablesOutput = strings.TrimSpace(iptablesOutput)
linesWithComments := strings.Split(iptablesOutput, "\n")
// Filter out lines starting with a '#' character
lines := make([]string, 0, len(linesWithComments))
for _, line := range linesWithComments {
if strings.HasPrefix(line, "#") {
continue
}
lines = append(lines, line)
}
const minLines = 2 // chain general information line + legend line
if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput)
}
c, err = parseChainGeneralDataLine(lines[0])
if err != nil {
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
}
// Sanity check for the legend line
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
}
lines = lines[2:] // remove chain general information line and legend line
if len(lines) == 0 {
return c, nil
}
c.rules = make([]chainRule, len(lines))
for i, line := range lines {
c.rules[i], err = parseChainRuleLine(line)
if err != nil {
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
}
}
return c, nil
}
// parseChainGeneralDataLine parses the first line of iptables chain list output.
// For example, it can parse the following line:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// It returns a chain struct with the parsed data.
func parseChainGeneralDataLine(line string) (base chain, err error) {
line = strings.TrimSpace(line)
runesToRemove := []rune{'(', ')', ','}
for _, r := range runesToRemove {
line = strings.ReplaceAll(line, string(r), "")
}
fields := strings.Fields(line)
const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line)
}
// Sanity checks
indexToExpectedValue := map[int]string{
0: "Chain",
2: "policy",
5: "packets",
7: "bytes",
}
for index, expectedValue := range indexToExpectedValue {
if fields[index] == expectedValue {
continue
}
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line)
}
base.name = fields[1] // chain name could be custom
base.policy = fields[3]
err = checkTarget(base.policy)
if err != nil {
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
}
packets, err := parseMetricSize(fields[4])
if err != nil {
return chain{}, fmt.Errorf("parsing packets: %w", err)
}
base.packets = packets
bytes, err := parseMetricSize(fields[6])
if err != nil {
return chain{}, fmt.Errorf("parsing bytes: %w", err)
}
base.bytes = bytes
return base, nil
}
var (
ErrChainRuleMalformed = errors.New("chain rule is malformed")
)
func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line)
if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
}
fields := strings.Fields(line)
const minFields = 10
if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
}
for fieldIndex, field := range fields[:minFields] {
err = parseChainRuleField(fieldIndex, field, &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
}
}
if len(fields) > minFields {
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
}
}
return rule, nil
}
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
}
const (
numIndex = iota
packetsIndex
bytesIndex
targetIndex
protocolIndex
optIndex
inputInterfaceIndex
outputInterfaceIndex
sourceIndex
destinationIndex
)
switch fieldIndex {
case numIndex:
rule.lineNumber, err = parseLineNumber(field)
if err != nil {
return fmt.Errorf("parsing line number: %w", err)
}
case packetsIndex:
rule.packets, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing packets: %w", err)
}
case bytesIndex:
rule.bytes, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing bytes: %w", err)
}
case targetIndex:
err = checkTarget(field)
if err != nil {
return fmt.Errorf("checking target: %w", err)
}
rule.target = field
case protocolIndex:
rule.protocol, err = parseProtocol(field)
if err != nil {
return fmt.Errorf("parsing protocol: %w", err)
}
case optIndex: // ignored
case inputInterfaceIndex:
rule.inputInterface = field
case outputInterfaceIndex:
rule.outputInterface = field
case sourceIndex:
rule.source, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case destinationIndex:
rule.destination, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
}
return nil
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i++
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port %q: %w", value, err)
}
rule.destinationPort = uint16(destinationPort)
case "redir":
i++
switch optionalFields[i] {
case "ports":
i++
ports, err := parsePortsCSV(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
default:
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
}
}
return nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
}
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(field, base, bitLength)
if err != nil {
return nil, fmt.Errorf("parsing port %q: %w", field, err)
}
ports[i] = uint16(port)
}
return ports, nil
}
var (
ErrLineNumberIsZero = errors.New("line number is zero")
)
func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil {
return 0, err
} else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
}
return uint16(lineNumber), nil
}
var (
ErrTargetUnknown = errors.New("unknown target")
)
func checkTarget(target string) (err error) {
switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil
}
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
}
var (
ErrProtocolUnknown = errors.New("unknown protocol")
)
func parseProtocol(s string) (protocol string, err error) {
switch s {
case "0":
case "6":
protocol = "tcp"
case "17":
protocol = "udp"
default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
}
return protocol, nil
}
var (
ErrMetricSizeMalformed = errors.New("metric size is malformed")
)
// parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) {
if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
}
//nolint:gomnd
multiplerLetterToValue := map[byte]uint64{
'K': 1000,
'M': 1000000,
'G': 1000000000,
'T': 1000000000000,
}
lastCharacter := size[len(size)-1]
multiplier, ok := multiplerLetterToValue[lastCharacter]
if ok { // multiplier present
size = size[:len(size)-1]
} else {
multiplier = 1
}
const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength)
if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
}
n *= multiplier
return n, nil
}

View File

@@ -1,121 +0,0 @@
package firewall
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseChain(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
iptablesOutput string
table chain
errWrapped error
errMessage string
}{
"no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
},
"single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
},
"malformed_general_data_line": {
iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"",
},
"malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
},
"no_rule": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
},
},
"some_rules": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
rules: []chainRule{
{
lineNumber: 1,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "udp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 2,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "tcp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 3,
packets: 0,
bytes: 0,
target: "DROP",
protocol: "",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -5,6 +5,12 @@ import (
"net/netip"
)
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
"no default route matching its family", subnet))

View File

@@ -1,3 +0,0 @@
package firewall
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger

View File

@@ -1,109 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
// Package firewall is a generated GoMock package.
package firewall
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
command "github.com/qdm12/golibs/command"
)
// MockRunner is a mock of Runner interface.
type MockRunner struct {
ctrl *gomock.Controller
recorder *MockRunnerMockRecorder
}
// MockRunnerMockRecorder is the mock recorder for MockRunner.
type MockRunnerMockRecorder struct {
mock *MockRunner
}
// NewMockRunner creates a new mock instance.
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
mock := &MockRunner{ctrl: ctrl}
mock.recorder = &MockRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
}
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}

View File

@@ -1,166 +0,0 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
}
func (i *iptablesInstruction) setDefaults() {
if i.table == "" {
i.table = "filter"
}
}
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
switch {
case i.table != table:
return false
case i.chain != chain:
return false
case i.target != rule.target:
return false
case i.protocol != rule.protocol:
return false
case i.destinationPort != rule.destinationPort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
return false
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
return false
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
return false
case !ipPrefixesEqual(i.source, rule.source):
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
default:
return true
}
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
}
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
return instruction == chainRule ||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var (
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
)
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.chain = value
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match": // ignore match
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
instruction.outputInterface = value
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
portStrings := strings.Split(value, ",")
instruction.toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
}
return nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
return netip.ParsePrefix(value)
}
ip, err := netip.ParseAddr(value)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}

View File

@@ -1,138 +0,0 @@
package firewall
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseIptablesInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
s string
instruction iptablesInstruction
errWrapped error
errMessage string
}{
"no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction",
},
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("5.6.7.8/32"),
destinationPort: 10000,
target: "ACCEPT",
},
},
"nat_redirection": {
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
target: "REDIRECT",
toPorts: []uint16{5678},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
func Test_parseIPPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
value string
prefix netip.Prefix
errMessage string
}{
"empty": {
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
},
"invalid": {
value: "invalid",
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
},
"valid_ipv4_with_bits": {
value: "10.0.0.0/16",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
},
"valid_ipv4_without_bits": {
value: "10.0.0.4",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
},
"valid_ipv6_with_bits": {
value: "2001:db8::/32",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
32),
},
"valid_ipv6_without_bits": {
value: "2001:db8::",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
128),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix, err := parseIPPrefix(testCase.value)
assert.Equal(t, testCase.prefix, prefix)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
// Package firewall is a generated GoMock package.
package firewall
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
command "github.com/qdm12/golibs/command"
)
// MockRunner is a mock of Runner interface.
type MockRunner struct {
ctrl *gomock.Controller
recorder *MockRunnerMockRecorder
}
// MockRunnerMockRecorder is the mock recorder for MockRunner.
type MockRunnerMockRecorder struct {
mock *MockRunner
}
// NewMockRunner creates a new mock instance.
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
mock := &MockRunner{ctrl: ctrl}
mock.recorder = &MockRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
}

View File

@@ -11,6 +11,8 @@ import (
"github.com/stretchr/testify/require"
)
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
func newAppendTestRuleMatcher(path string) *cmdMatcher {
return newCmdMatcher(path,
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",

View File

@@ -21,7 +21,7 @@ type Connection struct {
PubKey string `json:"pubkey"`
// ServerName is used for PIA for port forwarding
ServerName string `json:"server_name,omitempty"`
// PortForward is used for PIA and ProtonVPN for port forwarding
// PortForward is used for PIA for port forwarding
PortForward bool `json:"port_forward"`
}

View File

@@ -125,7 +125,7 @@ func getMarkdownHeaders(vpnProvider string) (headers []string) {
case providers.Mullvad:
return []string{countryHeader, cityHeader, ispHeader, ownedHeader, hostnameHeader, vpnHeader}
case providers.Nordvpn:
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, vpnHeader, categoriesHeader}
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, categoriesHeader}
case providers.Perfectprivacy:
return []string{cityHeader, tcpHeader, udpHeader}
case providers.Privado:
@@ -135,15 +135,14 @@ func getMarkdownHeaders(vpnProvider string) (headers []string) {
case providers.Privatevpn:
return []string{countryHeader, cityHeader, hostnameHeader}
case providers.Protonvpn:
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, vpnHeader,
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader,
freeHeader, portForwardHeader, secureHeader, torHeader}
case providers.Purevpn:
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, tcpHeader, udpHeader}
case providers.SlickVPN:
return []string{regionHeader, countryHeader, cityHeader, hostnameHeader}
case providers.Surfshark:
return []string{regionHeader, countryHeader, cityHeader, hostnameHeader,
vpnHeader, multiHopHeader, tcpHeader, udpHeader}
return []string{regionHeader, countryHeader, cityHeader, hostnameHeader, multiHopHeader, tcpHeader, udpHeader}
case providers.Torguard:
return []string{countryHeader, cityHeader, hostnameHeader, tcpHeader, udpHeader}
case providers.VPNSecure:

View File

@@ -22,8 +22,8 @@ type Server struct {
Number uint16 `json:"number,omitempty"`
ServerName string `json:"server_name,omitempty"`
Hostname string `json:"hostname,omitempty"`
TCP bool `json:"tcp,omitempty"`
UDP bool `json:"udp,omitempty"`
TCP bool `json:"tcp,omitempty"` // TODO v4 rename to openvpn_tcp
UDP bool `json:"udp,omitempty"` // TODO v4 rename to openvpn_udp
OvpnX509 string `json:"x509,omitempty"`
RetroLoc string `json:"retroloc,omitempty"` // TODO remove in v4
MultiHop bool `json:"multihop,omitempty"`
@@ -38,6 +38,25 @@ type Server struct {
IPs []netip.Addr `json:"ips,omitempty"`
}
func (s *Server) SetVPN(vpnType string) {
switch s.VPN {
case "":
s.VPN = vpnType
case vpn.Both:
return
case vpn.OpenVPN:
if vpnType == vpn.Wireguard {
s.VPN = vpn.Both
}
case vpn.Wireguard:
if vpnType == vpn.OpenVPN {
s.VPN = vpn.Both
}
default:
panic(fmt.Sprintf("VPN type %q not supported", s.VPN))
}
}
var (
ErrVPNFieldEmpty = errors.New("vpn field is empty")
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
@@ -48,16 +67,18 @@ var (
)
func (s *Server) HasMinimumInformation() (err error) {
isOpenVPN := s.VPN == vpn.OpenVPN || s.VPN == vpn.Both
isWireguard := s.VPN == vpn.Wireguard || s.VPN == vpn.Both
switch {
case s.VPN == "":
return fmt.Errorf("%w", ErrVPNFieldEmpty)
case len(s.IPs) == 0:
return fmt.Errorf("%w", ErrIPsFieldEmpty)
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
case isWireguard && !isOpenVPN && (s.TCP || s.UDP):
return fmt.Errorf("%w", ErrNetworkProtocolSet)
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
case isOpenVPN && !s.TCP && !s.UDP:
return fmt.Errorf("%w", ErrNoNetworkProtocol)
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
case isWireguard && s.WgPubKey == "":
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
default:
return nil

View File

@@ -18,11 +18,6 @@ var (
func extractDataFromLines(lines []string) (
connection models.Connection, err error) {
for i, line := range lines {
hashSymbolIndex := strings.Index(line, "#")
if hashSymbolIndex >= 0 {
line = line[:hashSymbolIndex]
}
ip, port, protocol, err := extractDataFromLine(line)
if err != nil {
return connection, fmt.Errorf("on line %d: %w", i+1, err)

View File

@@ -40,12 +40,11 @@ func getOpenVPNConnection(extractor Extractor,
connection.Port = customPort
}
// assume all custom provider servers support port forwarding
connection.PortForward = true
if len(selection.Names) > 0 {
// Set the server name for PIA port forwarding code used
// together with the custom provider.
connection.ServerName = selection.Names[0]
connection.PortForward = true
}
return connection, nil
@@ -54,17 +53,17 @@ func getOpenVPNConnection(extractor Extractor,
func getWireguardConnection(selection settings.ServerSelection) (
connection models.Connection) {
connection = models.Connection{
Type: vpn.Wireguard,
IP: selection.Wireguard.EndpointIP,
Port: *selection.Wireguard.EndpointPort,
Protocol: constants.UDP,
PubKey: selection.Wireguard.PublicKey,
PortForward: true, // assume all custom provider servers support port forwarding
Type: vpn.Wireguard,
IP: selection.Wireguard.EndpointIP,
Port: *selection.Wireguard.EndpointPort,
Protocol: constants.UDP,
PubKey: selection.Wireguard.PublicKey,
}
if len(selection.Names) > 0 {
// Set the server name for PIA port forwarding code used
// together with the custom provider.
connection.ServerName = selection.Names[0]
connection.PortForward = true
}
return connection
}

View File

@@ -68,26 +68,17 @@ func (hts hostToServerData) adaptWithIPs(hostToIPs map[string][]netip.Addr) {
func (hts hostToServerData) toServersSlice() (servers []models.Server) {
servers = make([]models.Server, 0, 2*len(hts)) //nolint:gomnd
for hostname, serverData := range hts {
baseServer := models.Server{
server := models.Server{
VPN: vpn.Both,
Hostname: hostname,
Country: serverData.country,
City: serverData.city,
IPs: serverData.ips,
TCP: serverData.openvpnTCP,
UDP: serverData.openvpnUDP,
WgPubKey: "658QxufMbjOTmB61Z7f+c7Rjg7oqWLnepTalqBERjF0=",
}
if serverData.openvpn {
openvpnServer := baseServer
openvpnServer.VPN = vpn.OpenVPN
openvpnServer.TCP = serverData.openvpnTCP
openvpnServer.UDP = serverData.openvpnUDP
servers = append(servers, openvpnServer)
}
if serverData.wireguard {
wireguardServer := baseServer
wireguardServer.VPN = vpn.Wireguard
const wireguardPublicKey = "658QxufMbjOTmB61Z7f+c7Rjg7oqWLnepTalqBERjF0="
wireguardServer.WgPubKey = wireguardPublicKey
servers = append(servers, wireguardServer)
}
servers = append(servers, server)
}
return servers
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"sort"
"strings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
@@ -59,11 +58,9 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
servers = make([]models.Server, 0, len(hostToIPs))
for _, serverData := range data.Servers {
city, region := parseCity(serverData.City)
server := models.Server{
Country: serverData.Country,
City: city,
Region: region,
City: serverData.City,
ISP: serverData.ISP,
}
@@ -99,11 +96,3 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
return servers, nil
}
func parseCity(city string) (parsedCity, region string) {
commaIndex := strings.Index(city, ", ")
if commaIndex == -1 {
return city, ""
}
return city[:commaIndex], city[commaIndex+2:]
}

View File

@@ -35,11 +35,11 @@ func (hts hostToServer) add(data serverData) (err error) {
switch data.Type {
case "openvpn":
server.VPN = vpn.OpenVPN
server.SetVPN(vpn.OpenVPN)
server.UDP = true
server.TCP = true
case "wireguard":
server.VPN = vpn.Wireguard
server.SetVPN(vpn.Wireguard)
case "bridge":
// ignore bridge servers
return nil

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
)
// Check out the JSON data from https://api.nordvpn.com/v2/servers?limit=10
@@ -93,9 +92,6 @@ func (s serversData) idToData() (
) {
groups = make(map[uint32]groupData, len(s.Groups))
for _, group := range s.Groups {
if group.Type.Identifier == "regions" { //nolint:goconst
group.Title = strings.ReplaceAll(group.Title, ",", "")
}
groups[group.ID] = group
}

View File

@@ -79,7 +79,7 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
server := models.Server{
Country: location.Country.Name,
Region: region,
Region: jsonServer.region(groups),
City: location.Country.City.Name,
Categories: jsonServer.categories(groups),
Hostname: jsonServer.Hostname,
@@ -98,12 +98,6 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
server.Number = number
}
var wireguardFound, openvpnFound bool
wireguardServer := server
wireguardServer.VPN = vpn.Wireguard
openVPNServer := server // accumulate UDP+TCP technologies
openVPNServer.VPN = vpn.OpenVPN
for _, technology := range jsonServer.Technologies {
if technology.Status != "online" {
continue
@@ -118,33 +112,25 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
switch technologyData.Identifier {
case "openvpn_udp", "openvpn_dedicated_udp":
openvpnFound = true
openVPNServer.UDP = true
server.SetVPN(vpn.OpenVPN)
server.UDP = true
case "openvpn_tcp", "openvpn_dedicated_tcp":
openvpnFound = true
openVPNServer.TCP = true
server.SetVPN(vpn.OpenVPN)
server.TCP = true
case "wireguard_udp":
wireguardFound = true
wireguardServer.WgPubKey, err = jsonServer.wireguardPublicKey(technologies)
server.WgPubKey, err = jsonServer.wireguardPublicKey(technologies)
if err != nil {
warning := fmt.Sprintf("ignoring Wireguard server %s: %s", jsonServer.Name, err)
warnings = append(warnings, warning)
wireguardFound = false
continue
}
server.SetVPN(vpn.Wireguard)
default: // Ignore other technologies
continue
}
}
const maxServers = 2
servers = make([]models.Server, 0, maxServers)
if openvpnFound {
servers = append(servers, openVPNServer)
}
if wireguardFound {
servers = append(servers, wireguardServer)
}
servers = append(servers, server)
return servers, warnings
}

View File

@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
}
serverName := objects.ServerName
apiIP := buildAPIIPAddress(objects.Gateway)
logger := objects.Logger
if !objects.CanPortForward {
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
if !dataFound || expired {
client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
p.portForwardPath, objects.Username, objects.Password)
if err != nil {
return nil, fmt.Errorf("refreshing port forward data: %w", err)
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
// First time binding
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
return nil, fmt.Errorf("binding port: %w", err)
}
@@ -100,8 +100,6 @@ func (p *Provider) KeepPortForward(ctx context.Context,
panic("gateway is not set")
}
apiIP := buildAPIIPAddress(objects.Gateway)
privateIPClient, err := newHTTPClient(objects.ServerName)
if err != nil {
return fmt.Errorf("creating custom HTTP client: %w", err)
@@ -129,7 +127,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
return ctx.Err()
case <-keepAliveTimer.C:
err = bindPort(ctx, privateIPClient, apiIP, data)
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
if err != nil {
return fmt.Errorf("binding port: %w", err)
}
@@ -141,25 +139,14 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
}
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
if gateway.Is6() {
panic("IPv6 gateway not supported")
}
gatewayBytes := gateway.As4()
gatewayBytes[2] = 128
gatewayBytes[3] = 1
return netip.AddrFrom4(gatewayBytes)
}
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, username, password)
if err != nil {
return data, fmt.Errorf("fetching token: %w", err)
}
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
if err != nil {
return data, fmt.Errorf("fetching port forwarding data: %w", err)
}
@@ -299,7 +286,7 @@ func fetchToken(ctx context.Context, client *http.Client,
return result.Token, nil
}
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
port uint16, signature string, expiration time.Time, err error) {
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
@@ -307,7 +294,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.
queryParams.Add("token", token)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(apiIP.String(), "19999"),
Host: net.JoinHostPort(gateway.String(), "19999"),
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
@@ -353,7 +340,7 @@ var (
ErrBadResponse = errors.New("bad response received")
)
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
payload, err := packPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return fmt.Errorf("serializing payload: %w", err)
@@ -364,7 +351,7 @@ func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr,
queryParams.Add("signature", data.Signature)
bindPortURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
Host: net.JoinHostPort(gateway.String(), "19999"),
Path: "/bindPort",
RawQuery: queryParams.Encode(),
}

View File

@@ -7,7 +7,6 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/updater/openvpn"
@@ -65,7 +64,6 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
continue
}
server := models.Server{
VPN: vpn.OpenVPN,
Country: country,
City: city,
IPs: ips,

View File

@@ -8,7 +8,7 @@ import (
func (p *Provider) GetConnection(selection settings.ServerSelection, ipv6Supported bool) (
connection models.Connection, err error) {
defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
return utils.GetConnection(p.Name(),
p.storage, selection, defaults, ipv6Supported, p.randSource)
}

View File

@@ -28,11 +28,10 @@ type logicalServer struct {
}
type physicalServer struct {
EntryIP netip.Addr `json:"EntryIP"`
ExitIP netip.Addr `json:"ExitIP"`
Domain string `json:"Domain"`
Status uint8 `json:"Status"`
X25519PublicKey string `json:"X25519PublicKey"`
EntryIP netip.Addr `json:"EntryIP"`
ExitIP netip.Addr `json:"ExitIP"`
Domain string `json:"Domain"`
Status uint8 `json:"Status"`
}
func fetchAPI(ctx context.Context, client *http.Client) (

View File

@@ -7,7 +7,7 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
type ipToServers map[string][2]models.Server // first server is OpenVPN, second is Wireguard.
type ipToServer map[string]models.Server
type features struct {
secureCore bool
@@ -16,50 +16,36 @@ type features struct {
stream bool
}
func (its ipToServers) add(country, region, city, name, hostname, wgPubKey string,
func (its ipToServer) add(country, region, city, name, hostname string,
free bool, entryIP netip.Addr, features features) {
key := entryIP.String()
servers, ok := its[key]
server, ok := its[key]
if ok {
return
}
baseServer := models.Server{
Country: country,
Region: region,
City: city,
ServerName: name,
Hostname: hostname,
Free: free,
SecureCore: features.secureCore,
Tor: features.tor,
PortForward: features.p2p,
Stream: features.stream,
IPs: []netip.Addr{entryIP},
}
openvpnServer := baseServer
openvpnServer.VPN = vpn.OpenVPN
openvpnServer.UDP = true
openvpnServer.TCP = true
servers[0] = openvpnServer
wireguardServer := baseServer
wireguardServer.VPN = vpn.Wireguard
wireguardServer.WgPubKey = wgPubKey
servers[1] = wireguardServer
its[key] = servers
server.VPN = vpn.OpenVPN
server.Country = country
server.Region = region
server.City = city
server.ServerName = name
server.Hostname = hostname
server.Free = free
server.SecureCore = features.secureCore
server.Tor = features.tor
server.PortForward = features.p2p
server.Stream = features.stream
server.UDP = true
server.TCP = true
server.IPs = []netip.Addr{entryIP}
its[key] = server
}
func (its ipToServers) toServersSlice() (serversSlice []models.Server) {
const vpnProtocols = 2
serversSlice = make([]models.Server, 0, vpnProtocols*len(its))
for _, servers := range its {
serversSlice = append(serversSlice, servers[0], servers[1])
func (its ipToServer) toServersSlice() (servers []models.Server) {
servers = make([]models.Server, 0, len(its))
for _, server := range its {
servers = append(servers, server)
}
return serversSlice
}
func (its ipToServers) numberOfServers() int {
const serversPerIP = 2
return len(its) * serversPerIP
return servers
}

View File

@@ -29,7 +29,7 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
common.ErrNotEnoughServers, count, minServers)
}
ipToServer := make(ipToServers, count)
ipToServer := make(ipToServer, count)
for _, logicalServer := range data.LogicalServers {
region := getStringValue(logicalServer.Region)
city := getStringValue(logicalServer.City)
@@ -65,7 +65,6 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
hostname := physicalServer.Domain
entryIP := physicalServer.EntryIP
wgPubKey := physicalServer.X25519PublicKey
// Note: for multi-hop use the server name or hostname
// instead of the country
@@ -75,11 +74,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
u.warner.Warn(warning)
}
ipToServer.add(country, region, city, name, hostname, wgPubKey, free, entryIP, features)
ipToServer.add(country, region, city, name, hostname, free, entryIP, features)
}
}
if ipToServer.numberOfServers() < minServers {
if len(ipToServer) < minServers {
return nil, fmt.Errorf("%w: %d and expected at least %d",
common.ErrNotEnoughServers, len(ipToServer), minServers)
}

View File

@@ -7,11 +7,12 @@ import (
"fmt"
"net/http"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/provider/surfshark/servers"
)
func addServersFromAPI(ctx context.Context, client *http.Client,
hts hostToServers) (err error) {
hts hostToServer) (err error) {
data, err := fetchAPI(ctx, client)
if err != nil {
return err
@@ -25,12 +26,14 @@ func addServersFromAPI(ctx context.Context, client *http.Client,
retroLoc := locationData.RetroLoc // empty string if the host has no retro-compatible region
tcp, udp := true, true // OpenVPN servers from API supports both TCP and UDP
hts.addOpenVPN(serverData.Host, serverData.Region, serverData.Country,
serverData.Location, retroLoc, tcp, udp)
const wgPubKey = ""
hts.add(serverData.Host, vpn.OpenVPN, serverData.Region, serverData.Country,
serverData.Location, retroLoc, wgPubKey, tcp, udp)
if serverData.PubKey != "" {
hts.addWireguard(serverData.Host, serverData.Region, serverData.Country,
serverData.Location, retroLoc, serverData.PubKey)
const wgTCP, wgUDP = false, false // unused
hts.add(serverData.Host, vpn.Wireguard, serverData.Region, serverData.Country,
serverData.Location, retroLoc, serverData.PubKey, wgTCP, wgUDP)
}
}

View File

@@ -24,9 +24,9 @@ func Test_addServersFromAPI(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
hts hostToServers
hts hostToServer
exchanges []httpExchange
expected hostToServers
expected hostToServer
err error
}{
"fetch API error": {
@@ -37,8 +37,8 @@ func Test_addServersFromAPI(t *testing.T) {
err: errors.New("HTTP status code not OK: 204 No Content"),
},
"success": {
hts: hostToServers{
"existinghost": []models.Server{{Hostname: "existinghost"}},
hts: hostToServer{
"existinghost": models.Server{Hostname: "existinghost"},
},
exchanges: []httpExchange{{
requestURL: "https://api.surfshark.com/v4/server/clusters/generic",
@@ -61,25 +61,19 @@ func Test_addServersFromAPI(t *testing.T) {
responseStatus: http.StatusOK,
responseBody: io.NopCloser(strings.NewReader(`[]`)),
}},
expected: map[string][]models.Server{
"existinghost": {{Hostname: "existinghost"}},
"host1": {{
VPN: vpn.OpenVPN,
expected: map[string]models.Server{
"existinghost": {Hostname: "existinghost"},
"host1": {
VPN: vpn.Both,
Region: "region1",
Country: "country1",
City: "location1",
Hostname: "host1",
TCP: true,
UDP: true,
}, {
VPN: vpn.Wireguard,
Region: "region1",
Country: "country1",
City: "location1",
Hostname: "host1",
WgPubKey: "pubKeyValue",
}},
"host2": {{
},
"host2": {
VPN: vpn.OpenVPN,
Region: "region2",
Country: "country1",
@@ -87,7 +81,7 @@ func Test_addServersFromAPI(t *testing.T) {
Hostname: "host2",
TCP: true,
UDP: true,
}},
},
},
},
}

View File

@@ -7,94 +7,62 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
type hostToServers map[string][]models.Server
type hostToServer map[string]models.Server
func (hts hostToServers) addOpenVPN(host, region, country, city,
retroLoc string, tcp, udp bool) {
// Check for existing server for this host and OpenVPN.
servers := hts[host]
for i, existingServer := range servers {
if existingServer.Hostname != host ||
existingServer.VPN != vpn.OpenVPN {
continue
}
// Update OpenVPN supported protocols and return
if !existingServer.TCP {
servers[i].TCP = tcp
}
if !existingServer.UDP {
servers[i].UDP = udp
func (hts hostToServer) add(host, vpnType, region, country, city,
retroLoc, wgPubKey string, openvpnTCP, openvpnUDP bool) {
server, ok := hts[host]
if !ok {
server := models.Server{
VPN: vpnType,
Region: region,
Country: country,
City: city,
RetroLoc: retroLoc,
Hostname: host,
WgPubKey: wgPubKey,
TCP: openvpnTCP,
UDP: openvpnUDP,
}
hts[host] = server
return
}
server := models.Server{
VPN: vpn.OpenVPN,
Region: region,
Country: country,
City: city,
RetroLoc: retroLoc,
Hostname: host,
TCP: tcp,
UDP: udp,
server.SetVPN(vpnType)
if vpnType == vpn.OpenVPN {
server.TCP = server.TCP || openvpnTCP
server.UDP = server.UDP || openvpnUDP
} else if wgPubKey != "" {
server.WgPubKey = wgPubKey
}
hts[host] = append(servers, server)
hts[host] = server
}
func (hts hostToServers) addWireguard(host, region, country, city, retroLoc,
wgPubKey string) {
// Check for existing server for this host and Wireguard.
servers := hts[host]
for _, existingServer := range servers {
if existingServer.Hostname == host &&
existingServer.VPN == vpn.Wireguard {
// No update necessary for Wireguard
return
}
}
server := models.Server{
VPN: vpn.Wireguard,
Region: region,
Country: country,
City: city,
RetroLoc: retroLoc,
Hostname: host,
WgPubKey: wgPubKey,
}
hts[host] = append(servers, server)
}
func (hts hostToServers) toHostsSlice() (hosts []string) {
const vpnServerTypes = 2 // OpenVPN + Wireguard
hosts = make([]string, 0, vpnServerTypes*len(hts))
func (hts hostToServer) toHostsSlice() (hosts []string) {
hosts = make([]string, 0, len(hts))
for host := range hts {
hosts = append(hosts, host)
}
return hosts
}
func (hts hostToServers) adaptWithIPs(hostToIPs map[string][]netip.Addr) {
for host, IPs := range hostToIPs {
servers := hts[host]
for i := range servers {
servers[i].IPs = IPs
}
hts[host] = servers
}
for host, servers := range hts {
if len(servers[0].IPs) == 0 {
func (hts hostToServer) adaptWithIPs(hostToIPs map[string][]netip.Addr) {
for host, server := range hts {
ips := hostToIPs[host]
if len(ips) == 0 {
delete(hts, host)
continue
}
server.IPs = ips
hts[host] = server
}
}
func (hts hostToServers) toServersSlice() (servers []models.Server) {
const vpnServerTypes = 2 // OpenVPN + Wireguard
servers = make([]models.Server, 0, vpnServerTypes*len(hts))
for _, serversForHost := range hts {
servers = append(servers, serversForHost...)
func (hts hostToServer) toServersSlice() (servers []models.Server) {
servers = make([]models.Server, 0, len(hts))
for _, server := range hts {
servers = append(servers, server)
}
return servers
}

View File

@@ -1,11 +1,12 @@
package updater
import (
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/provider/surfshark/servers"
)
// getRemainingServers finds extra servers not found in the API or in the ZIP file.
func getRemainingServers(hts hostToServers) {
func getRemainingServers(hts hostToServer) {
locationData := servers.LocationData()
hostnameToLocationLeft := hostToLocation(locationData)
for _, hostnameDone := range hts.toHostsSlice() {
@@ -15,7 +16,8 @@ func getRemainingServers(hts hostToServers) {
for hostname, locationData := range hostnameToLocationLeft {
// we assume the OpenVPN server supports both TCP and UDP
const tcp, udp = true, true
hts.addOpenVPN(hostname, locationData.Region, locationData.Country,
locationData.City, locationData.RetroLoc, tcp, udp)
const wgPubKey = ""
hts.add(hostname, vpn.OpenVPN, locationData.Region, locationData.Country,
locationData.City, locationData.RetroLoc, wgPubKey, tcp, udp)
}
}

View File

@@ -11,7 +11,7 @@ import (
func (u *Updater) FetchServers(ctx context.Context, minServers int) (
servers []models.Server, err error) {
hts := make(hostToServers)
hts := make(hostToServer)
err = addServersFromAPI(ctx, u.client, hts)
if err != nil {

View File

@@ -4,13 +4,14 @@ import (
"context"
"strings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/provider/surfshark/servers"
"github.com/qdm12/gluetun/internal/updater/openvpn"
)
func addOpenVPNServersFromZip(ctx context.Context,
unzipper common.Unzipper, hts hostToServers) (
unzipper common.Unzipper, hts hostToServer) (
warnings []string, err error) {
const url = "https://my.surfshark.com/vpn/api/v1/server/configurations"
contents, err := unzipper.FetchAndExtract(ctx, url)
@@ -66,8 +67,9 @@ func addOpenVPNServersFromZip(ctx context.Context,
continue
}
hts.addOpenVPN(host, data.Region, data.Country, data.City,
data.RetroLoc, tcp, udp)
const wgPubKey = ""
hts.add(host, vpn.OpenVPN, data.Region, data.Country, data.City,
data.RetroLoc, wgPubKey, tcp, udp)
}
return warnings, nil

View File

@@ -2,17 +2,13 @@ package server
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
"github.com/qdm12/gluetun/internal/server/middlewares/log"
)
func newHandler(ctx context.Context, logger Logger, logging bool,
authSettings auth.Settings,
func newHandler(ctx context.Context, logger infoWarner, logging bool,
buildInfo models.BuildInformation,
vpnLooper VPNLooper,
pfGetter PortForwardedGetter,
@@ -21,7 +17,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
publicIPLooper PublicIPLoop,
storage Storage,
ipv6Supported bool,
) (httpHandler http.Handler, err error) {
) http.Handler {
handler := &handler{}
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
@@ -33,25 +29,16 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
authMiddleware, err := auth.New(authSettings, logger)
if err != nil {
return nil, fmt.Errorf("creating auth middleware: %w", err)
}
handlerWithLog := withLogMiddleware(handler, logger, logging)
handler.setLogEnabled = handlerWithLog.setEnabled
middlewares := []func(http.Handler) http.Handler{
authMiddleware,
log.New(logger, logging),
}
httpHandler = handler
for _, middleware := range middlewares {
httpHandler = middleware(httpHandler)
}
return httpHandler, nil
return handlerWithLog
}
type handler struct {
v0 http.Handler
v1 http.Handler
v0 http.Handler
v1 http.Handler
setLogEnabled func(enabled bool)
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

View File

@@ -1,4 +1,4 @@
package log
package server
import (
"net/http"
@@ -7,21 +7,18 @@ import (
"time"
)
func New(logger Logger, enabled bool) (
middleware func(http.Handler) http.Handler) {
return func(handler http.Handler) http.Handler {
return &logMiddleware{
childHandler: handler,
logger: logger,
timeNow: time.Now,
enabled: enabled,
}
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
return &logMiddleware{
childHandler: childHandler,
logger: logger,
timeNow: time.Now,
enabled: enabled,
}
}
type logMiddleware struct {
childHandler http.Handler
logger Logger
logger infoer
timeNow func() time.Time
enabled bool
enabledMu sync.RWMutex
@@ -42,7 +39,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.RemoteAddr + " in " + duration.String())
}
func (m *logMiddleware) SetEnabled(enabled bool) {
func (m *logMiddleware) setEnabled(enabled bool) {
m.enabledMu.Lock()
defer m.enabledMu.Unlock()
m.enabled = enabled

View File

@@ -1,10 +1,8 @@
package server
type Logger interface {
Debugf(format string, args ...any)
infoer
warner
Warnf(format string, args ...any)
errorer
}

View File

@@ -1,36 +0,0 @@
package auth
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
)
type apiKeyMethod struct {
apiKeyDigest [32]byte
}
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
return &apiKeyMethod{
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
otherTokenMethod, ok := other.(*apiKeyMethod)
if !ok {
return false
}
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
}
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
xAPIKey := request.Header.Get("X-API-Key")
if xAPIKey == "" {
xAPIKey = request.URL.Query().Get("api_key")
}
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
}

View File

@@ -1,37 +0,0 @@
package auth
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
)
type basicAuthMethod struct {
authDigest [32]byte
}
func newBasicAuthMethod(username, password string) *basicAuthMethod {
return &basicAuthMethod{
authDigest: sha256.Sum256([]byte(username + password)),
}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
otherBasicMethod, ok := other.(*basicAuthMethod)
if !ok {
return false
}
return a.authDigest == otherBasicMethod.authDigest
}
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
username, password, ok := request.BasicAuth()
if !ok {
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
return false
}
requestAuthDigest := sha256.Sum256([]byte(username + password))
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
}

View File

@@ -1,35 +0,0 @@
package auth
import (
"errors"
"fmt"
"os"
"github.com/pelletier/go-toml/v2"
)
// Read reads the toml file specified by the filepath given.
// If the file does not exist, it returns empty settings and no error.
func Read(filepath string) (settings Settings, err error) {
file, err := os.Open(filepath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return Settings{}, nil
}
return settings, fmt.Errorf("opening file: %w", err)
}
decoder := toml.NewDecoder(file)
decoder.DisallowUnknownFields()
err = decoder.Decode(&settings)
if err == nil {
return settings, nil
}
strictErr := new(toml.StrictMissingError)
ok := errors.As(err, &strictErr)
if !ok {
return settings, fmt.Errorf("toml decoding file: %w", err)
}
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
strictErr, strictErr.String())
}

View File

@@ -1,80 +0,0 @@
package auth
import (
"io/fs"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Read reads the toml file specified by the filepath given.
func Test_Read(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
fileContent string
settings Settings
errMessage string
}{
"empty_file": {},
"malformed_toml": {
fileContent: "this is not a toml file",
errMessage: `toml decoding file: toml: expected character =`,
},
"unknown_field": {
fileContent: `unknown = "what is this"`,
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
1| unknown = "what is this"
| ~~~~~~~ missing field`,
},
"filled_settings": {
fileContent: `[[roles]]
name = "public"
auth = "none"
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
[[roles]]
name = "client"
auth = "apikey"
apikey = "xyz"
routes = ["GET /v1/vpn/status"]
`,
settings: Settings{
Roles: []Role{{
Name: "public",
Auth: AuthNone,
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
}, {
Name: "client",
Auth: AuthAPIKey,
APIKey: "xyz",
Routes: []string{"GET /v1/vpn/status"},
}},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
filepath := tempDir + "/config.toml"
const permissions fs.FileMode = 0600
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
require.NoError(t, err)
settings, err := Read(filepath)
assert.Equal(t, testCase.settings, settings)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,22 +0,0 @@
package auth
func andStrings(strings []string) (result string) {
return joinStrings(strings, "and")
}
func joinStrings(strings []string, lastJoin string) (result string) {
if len(strings) == 0 {
return ""
}
result = strings[0]
for i := 1; i < len(strings); i++ {
if i < len(strings)-1 {
result += ", " + strings[i]
} else {
result += " " + lastJoin + " " + strings[i]
}
}
return result
}

View File

@@ -1,6 +0,0 @@
package auth
type DebugLogger interface {
Debugf(format string, args ...any)
Warnf(format string, args ...any)
}

View File

@@ -1,8 +0,0 @@
package auth
import "net/http"
type authorizationChecker interface {
equal(other authorizationChecker) bool
isAuthorized(headers http.Header, request *http.Request) bool
}

View File

@@ -1,47 +0,0 @@
package auth
import (
"fmt"
)
type internalRole struct {
name string
checker authorizationChecker
}
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
routeToRoles = make(map[string][]internalRole)
for _, role := range settings.Roles {
var checker authorizationChecker
switch role.Auth {
case AuthNone:
checker = newNoneMethod()
case AuthAPIKey:
checker = newAPIKeyMethod(role.APIKey)
case AuthBasic:
checker = newBasicAuthMethod(role.Username, role.Password)
default:
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
}
iRole := internalRole{
name: role.Name,
checker: checker,
}
for _, route := range role.Routes {
checkerExists := false
for _, role := range routeToRoles[route] {
if role.checker.equal(iRole.checker) {
checkerExists = true
break
}
}
if checkerExists {
// even if the role name is different, if the checker is the same, skip it.
continue
}
routeToRoles[route] = append(routeToRoles[route], iRole)
}
}
return routeToRoles, nil
}

View File

@@ -1,60 +0,0 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Read reads the toml file specified by the filepath given.
func Test_settingsToLookupMap(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
routeToRoles map[string][]internalRole
errWrapped error
errMessage string
}{
"empty_settings": {
routeToRoles: map[string][]internalRole{},
},
"auth_method_not_supported": {
settings: Settings{
Roles: []Role{{Name: "a", Auth: "bad"}},
},
errWrapped: ErrMethodNotSupported,
errMessage: "authentication method not supported: bad",
},
"success": {
settings: Settings{
Roles: []Role{
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
},
},
routeToRoles: map[string][]internalRole{
"GET /path": {
{name: "a", checker: newNoneMethod()}, // deduplicated method
},
"PUT /path": {
{name: "b", checker: newNoneMethod()},
}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
routeToRoles, err := settingsToLookupMap(testCase.settings)
assert.Equal(t, testCase.routeToRoles, routeToRoles)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -1,111 +0,0 @@
package auth
import (
"fmt"
"net/http"
)
func New(settings Settings, debugLogger DebugLogger) (
middleware func(http.Handler) http.Handler,
err error) {
routeToRoles, err := settingsToLookupMap(settings)
if err != nil {
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
}
//nolint:goconst
return func(handler http.Handler) http.Handler {
return &authHandler{
childHandler: handler,
routeToRoles: routeToRoles,
unprotectedRoutes: map[string]struct{}{
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
// GET /v1/vpn/settings is protected by default
// PUT /v1/vpn/settings is protected by default
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
// GET /v1/openvpn/settings is protected by default
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
},
logger: debugLogger,
}
}, nil
}
type authHandler struct {
childHandler http.Handler
routeToRoles map[string][]internalRole
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
logger DebugLogger
}
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
route := request.Method + " " + request.URL.Path
roles := h.routeToRoles[route]
if len(roles) == 0 {
h.logger.Debugf("no authentication role defined for route %s", route)
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
responseHeader := make(http.Header, 0)
for _, role := range roles {
if !role.checker.isAuthorized(responseHeader, request) {
continue
}
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
h.childHandler.ServeHTTP(writer, request)
return
}
// Flush out response headers if all roles failed to authenticate
for headerKey, headerValues := range responseHeader {
for _, headerValue := range headerValues {
writer.Header().Add(headerKey, headerValue)
}
}
allRoleNames := make([]string, len(roles))
for i, role := range roles {
allRoleNames[i] = role.name
}
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
route, andStrings(allRoleNames))
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
// TODO v3.41.0 remove
if role.name != "public" {
// custom role name, allow none authentication to be specified
return
}
_, isNoneChecker := role.checker.(*noneMethod)
if !isNoneChecker {
// not the none authentication method
return
}
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
if !isUnprotectedByDefault {
// route is not unprotected by default, so this is a user decision
return
}
h.logger.Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
route)
}

View File

@@ -1,124 +0,0 @@
package auth
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_authHandler_ServeHTTP(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
requestMethod string
requestPath string
statusCode int
responseBody string
}{
"route_has_no_role": {
settings: Settings{
Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/b",
statusCode: http.StatusUnauthorized,
responseBody: "Unauthorized\n",
},
"authorized_unprotected_by_default": {
settings: Settings{
Roles: []Role{
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Warnf("route %s is unprotected by default, "+
"please set up authentication following the documentation at "+
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
"since this will become no longer publicly accessible after release v3.40.",
"GET /v1/vpn/status")
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /v1/vpn/status", "public")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/v1/vpn/status",
statusCode: http.StatusOK,
},
"authorized_none": {
settings: Settings{
Roles: []Role{
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
},
},
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
logger := NewMockDebugLogger(ctrl)
logger.EXPECT().Debugf("access to route %s authorized for role %s",
"GET /a", "role1")
return logger
},
requestMethod: http.MethodGet,
requestPath: "/a",
statusCode: http.StatusOK,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var debugLogger DebugLogger
if testCase.makeLogger != nil {
debugLogger = testCase.makeLogger(ctrl)
}
middleware, err := New(testCase.settings, debugLogger)
require.NoError(t, err)
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := middleware(childHandler)
server := httptest.NewServer(handler)
t.Cleanup(server.Close)
client := server.Client()
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
require.NoError(t, err)
request, err := http.NewRequestWithContext(context.Background(),
testCase.requestMethod, requestURL, nil)
require.NoError(t, err)
response, err := client.Do(request)
require.NoError(t, err)
t.Cleanup(func() {
err = response.Body.Close()
assert.NoError(t, err)
})
assert.Equal(t, testCase.statusCode, response.StatusCode)
body, err := io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, testCase.responseBody, string(body))
})
}
}

View File

@@ -1,3 +0,0 @@
package auth
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger

View File

@@ -1,68 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
// Package auth is a generated GoMock package.
package auth
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockDebugLogger is a mock of DebugLogger interface.
type MockDebugLogger struct {
ctrl *gomock.Controller
recorder *MockDebugLoggerMockRecorder
}
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
type MockDebugLoggerMockRecorder struct {
mock *MockDebugLogger
}
// NewMockDebugLogger creates a new mock instance.
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
mock := &MockDebugLogger{ctrl: ctrl}
mock.recorder = &MockDebugLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
return m.recorder
}
// Debugf mocks base method.
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debugf", varargs...)
}
// Debugf indicates an expected call of Debugf.
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
}
// Warnf mocks base method.
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Warnf", varargs...)
}
// Warnf indicates an expected call of Warnf.
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
}

View File

@@ -1,20 +0,0 @@
package auth
import "net/http"
type noneMethod struct{}
func newNoneMethod() *noneMethod {
return &noneMethod{}
}
// equal returns true if another auth checker is equal.
// This is used to deduplicate checkers for a particular route.
func (n *noneMethod) equal(other authorizationChecker) bool {
_, ok := other.(*noneMethod)
return ok
}
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
return true
}

View File

@@ -1,131 +0,0 @@
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/validate"
)
type Settings struct {
// Roles is a list of roles with their associated authentication
// and routes.
Roles []Role
}
func (s *Settings) SetDefaults() {
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
Name: "public",
Auth: "none",
Routes: []string{
http.MethodGet + " /openvpn/actions/restart",
http.MethodGet + " /unbound/actions/restart",
http.MethodGet + " /updater/restart",
http.MethodGet + " /v1/version",
http.MethodGet + " /v1/vpn/status",
http.MethodPut + " /v1/vpn/status",
http.MethodGet + " /v1/openvpn/status",
http.MethodPut + " /v1/openvpn/status",
http.MethodGet + " /v1/openvpn/portforwarded",
http.MethodGet + " /v1/dns/status",
http.MethodPut + " /v1/dns/status",
http.MethodGet + " /v1/updater/status",
http.MethodPut + " /v1/updater/status",
http.MethodGet + " /v1/publicip/ip",
},
}})
}
func (s Settings) Validate() (err error) {
for i, role := range s.Roles {
err = role.validate()
if err != nil {
return fmt.Errorf("role %s (%d of %d): %w",
role.Name, i+1, len(s.Roles), err)
}
}
return nil
}
const (
AuthNone = "none"
AuthAPIKey = "apikey"
AuthBasic = "basic"
)
// Role contains the role name, authentication method name and
// routes that the role can access.
type Role struct {
// Name is the role name and is only used for documentation
// and in the authentication middleware debug logs.
Name string
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
Auth string
// APIKey is the API key to use when using the 'apikey' authentication.
APIKey string
// Username for HTTP Basic authentication method.
Username string
// Password for HTTP Basic authentication method.
Password string
// Routes is a list of routes that the role can access in the format
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
Routes []string
}
var (
ErrMethodNotSupported = errors.New("authentication method not supported")
ErrAPIKeyEmpty = errors.New("api key is empty")
ErrBasicUsernameEmpty = errors.New("username is empty")
ErrBasicPasswordEmpty = errors.New("password is empty")
ErrRouteNotSupported = errors.New("route not supported by the control server")
)
func (r Role) validate() (err error) {
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
if err != nil {
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
}
switch {
case r.Auth == AuthAPIKey && r.APIKey == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
case r.Auth == AuthBasic && r.Username == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
case r.Auth == AuthBasic && r.Password == "":
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
}
for i, route := range r.Routes {
_, ok := validRoutes[route]
if !ok {
return fmt.Errorf("route %d of %d: %w: %s",
i+1, len(r.Routes), ErrRouteNotSupported, route)
}
}
return nil
}
// WARNING: do not mutate programmatically.
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {},
http.MethodGet + " /v1/vpn/status": {},
http.MethodPut + " /v1/vpn/status": {},
http.MethodGet + " /v1/vpn/settings": {},
http.MethodPut + " /v1/vpn/settings": {},
http.MethodGet + " /v1/openvpn/status": {},
http.MethodPut + " /v1/openvpn/status": {},
http.MethodGet + " /v1/openvpn/portforwarded": {},
http.MethodGet + " /v1/openvpn/settings": {},
http.MethodGet + " /v1/dns/status": {},
http.MethodPut + " /v1/dns/status": {},
http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {},
}

View File

@@ -1,5 +0,0 @@
package log
type Logger interface {
Info(message string)
}

View File

@@ -6,31 +6,17 @@ import (
"github.com/qdm12/gluetun/internal/httpserver"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
)
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) (
server *httpserver.Server, err error) {
authSettings, err := auth.Read(authConfigPath)
if err != nil {
return nil, fmt.Errorf("reading auth settings: %w", err)
}
authSettings.SetDefaults()
err = authSettings.Validate()
if err != nil {
return nil, fmt.Errorf("validating auth settings: %w", err)
}
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
handler := newHandler(ctx, logger, logEnabled, buildInfo,
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
if err != nil {
return nil, fmt.Errorf("creating handler: %w", err)
}
httpServerSettings := httpserver.Settings{
Address: address,

View File

@@ -50,11 +50,11 @@ func filterServer(server models.Server,
selection settings.ServerSelection) (filtered bool) {
// Note each condition is split to make sure
// we have full testing coverage.
if server.VPN != selection.VPN {
if server.VPN != vpn.Both && server.VPN != selection.VPN {
return true
}
if server.VPN != vpn.Wireguard &&
if selection.VPN == vpn.OpenVPN &&
filterByProtocol(selection, server.TCP, server.UDP) {
return true
}
@@ -119,8 +119,6 @@ func filterServer(server models.Server,
return true
}
// TODO filter port forward server for PIA
return false
}

View File

@@ -128,31 +128,6 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
messageParts = append(messageParts, "premium tier only")
}
if *selection.StreamOnly {
messageParts = append(messageParts, "stream only")
}
if *selection.MultiHopOnly {
messageParts = append(messageParts, "multihop only")
}
if *selection.PortForwardOnly {
messageParts = append(messageParts, "port forwarding only")
}
if *selection.SecureCoreOnly {
messageParts = append(messageParts, "secure core only")
}
if *selection.TorOnly {
messageParts = append(messageParts, "tor only")
}
if selection.TargetIP.IsValid() {
messageParts = append(messageParts,
"target ip address "+selection.TargetIP.String())
}
message := "for " + strings.Join(messageParts, "; ")
return fmt.Errorf("%w: %s", ErrNoServerFound, message)

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/http"
"sort"
"time"
"github.com/qdm12/gluetun/internal/models"
@@ -50,17 +49,13 @@ func getLatestRelease(ctx context.Context, client *http.Client) (tagName, name s
if err != nil {
return "", "", time, err
}
// Sort releases by tag names (semver)
sort.Slice(releases, func(i, j int) bool {
return releases[i].TagName > releases[j].TagName
})
for _, release := range releases {
if release.Prerelease {
continue
}
return release.TagName, release.Name, release.PublishedAt, nil
}
return "", "", time, fmt.Errorf("%w", errReleaseNotFound)
return "", "", time, errReleaseNotFound
}
var errCommitNotFound = errors.New("commit not found")