Compare commits
1 Commits
v3.39.1
...
v4-storage
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a6e8d74d6 |
@@ -37,12 +37,12 @@
|
||||
"go.useLanguageServer": true,
|
||||
"[go]": {
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit"
|
||||
"source.organizeImports": true
|
||||
}
|
||||
},
|
||||
"[go.mod]": {
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit"
|
||||
"source.organizeImports": true
|
||||
}
|
||||
},
|
||||
"gopls": {
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/config.yml
vendored
4
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,10 +1,10 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Report a Wiki issue
|
||||
url: https://github.com/qdm12/gluetun-wiki/issues/new/choose
|
||||
url: https://github.com/qdm12/gluetun-wiki/issues/new
|
||||
about: Please create an issue on the gluetun-wiki repository.
|
||||
- name: Configuration help?
|
||||
url: https://github.com/qdm12/gluetun/discussions/new/choose
|
||||
url: https://github.com/qdm12/gluetun/discussions/new
|
||||
about: Please create a Github discussion.
|
||||
- name: Unraid template issue
|
||||
url: https://github.com/qdm12/gluetun/discussions/550
|
||||
|
||||
8
.github/labels.yml
vendored
8
.github/labels.yml
vendored
@@ -3,9 +3,6 @@
|
||||
- name: "Status: 🔴 Blocked"
|
||||
color: "f7d692"
|
||||
description: "Blocked by another issue or pull request"
|
||||
- name: "Status: 📌 Before next release"
|
||||
color: "f7d692"
|
||||
description: "Has to be done before the next release"
|
||||
- name: "Status: 🔒 After next release"
|
||||
color: "f7d692"
|
||||
description: "Will be done after the next release"
|
||||
@@ -39,8 +36,6 @@
|
||||
# VPN providers
|
||||
- name: "☁️ AirVPN"
|
||||
color: "cfe8d4"
|
||||
- name: "☁️ Custom"
|
||||
color: "cfe8d4"
|
||||
- name: "☁️ Cyberghost"
|
||||
color: "cfe8d4"
|
||||
- name: "☁️ HideMyAss"
|
||||
@@ -96,9 +91,6 @@
|
||||
- name: "Category: Maintenance ⛓️"
|
||||
description: "Anything related to code or other maintenance"
|
||||
color: "ffc7ea"
|
||||
- name: "Category: Logs 📚"
|
||||
description: "Something to change in logs"
|
||||
color: "ffc7ea"
|
||||
- name: "Category: Good idea 🎯"
|
||||
description: "This is a good idea, judged by the maintainers"
|
||||
color: "ffc7ea"
|
||||
|
||||
@@ -120,3 +120,9 @@ linters:
|
||||
- wastedassign
|
||||
- whitespace
|
||||
- zerologlint
|
||||
|
||||
run:
|
||||
skip-dirs:
|
||||
- .devcontainer
|
||||
- .github
|
||||
- doc
|
||||
|
||||
@@ -197,7 +197,6 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
# Control server
|
||||
HTTP_CONTROL_SERVER_LOG=on \
|
||||
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
||||
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
|
||||
# Server data updater
|
||||
UPDATER_PERIOD=0 \
|
||||
UPDATER_MIN_RATIO=0.8 \
|
||||
|
||||
@@ -60,8 +60,8 @@ Lightweight swiss-knife-like VPN client to multiple VPN service providers
|
||||
- Supports: **AirVPN**, **Cyberghost**, **ExpressVPN**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **SlickVPN**, **Surfshark**, **TorGuard**, **VPNSecure.me**, **VPNUnlimited**, **Vyprvpn**, **WeVPN**, **Windscribe** servers
|
||||
- Supports OpenVPN for all providers listed
|
||||
- Supports Wireguard both kernelspace and userspace
|
||||
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **ProtonVPN**, **Surfshark** and **Windscribe**
|
||||
- For **Cyberghost**, **Private Internet Access**, **PrivateVPN**, **PureVPN**, **Torguard**, **VPN Unlimited**, **VyprVPN** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||
- For **AirVPN**, **FastestVPN**, **Ivpn**, **Mullvad**, **NordVPN**, **Perfect privacy**, **Surfshark** and **Windscribe**
|
||||
- For **ProtonVPN**, **PureVPN**, **Torguard**, **VPN Unlimited** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/custom.md)
|
||||
- More in progress, see [#134](https://github.com/qdm12/gluetun/issues/134)
|
||||
- DNS over TLS baked in with service provider(s) of your choice
|
||||
@@ -73,7 +73,7 @@ Lightweight swiss-knife-like VPN client to multiple VPN service providers
|
||||
- [Connect other containers to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-container-to-gluetun.md)
|
||||
- [Connect LAN devices to it](https://github.com/qdm12/gluetun-wiki/blob/main/setup/connect-a-lan-device-to-gluetun.md)
|
||||
- Compatible with amd64, i686 (32 bit), **ARM** 64 bit, ARM 32 bit v6 and v7, and even ppc64le 🎆
|
||||
- Custom VPN server side port forwarding for [Perfect Privacy](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/perfect-privacy.md#vpn-server-port-forwarding), [Private Internet Access](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/private-internet-access.md#vpn-server-port-forwarding) and [ProtonVPN](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/protonvpn.md#vpn-server-port-forwarding)
|
||||
- [Custom VPN server side port forwarding for Private Internet Access](https://github.com/qdm12/gluetun-wiki/blob/main/setup/providers/private-internet-access.md#vpn-server-port-forwarding)
|
||||
- Possibility of split horizon DNS by selecting multiple DNS over TLS providers
|
||||
- Unbound subprogram drops root privileges once launched
|
||||
- Can work as a Kubernetes sidecar container, thanks @rorph
|
||||
@@ -84,7 +84,7 @@ Lightweight swiss-knife-like VPN client to multiple VPN service providers
|
||||
|
||||
Go to the [Wiki](https://github.com/qdm12/gluetun-wiki)!
|
||||
|
||||
[🐛 Found a bug in the Wiki?!](https://github.com/qdm12/gluetun-wiki/issues/new/choose)
|
||||
[🐛 Found a bug in the Wiki?!](https://github.com/qdm12/gluetun-wiki/issues/new)
|
||||
|
||||
Here's a docker-compose.yml for the laziest:
|
||||
|
||||
|
||||
@@ -161,14 +161,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return cli.Update(ctx, args[2:], logger)
|
||||
case "format-servers":
|
||||
return cli.FormatServers(args[2:])
|
||||
case "genkey":
|
||||
return cli.GenKey(args[2:])
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||
}
|
||||
}
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -178,8 +176,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Emails: []string{"quentin.mcgaw@gmail.com"},
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
Created: buildInfo.Created,
|
||||
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||
BuildDate: buildInfo.Created,
|
||||
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -476,7 +474,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
||||
logger.New(log.SetComponent("http server")),
|
||||
allSettings.ControlServer.AuthFilePath,
|
||||
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
@@ -598,7 +595,6 @@ type clier interface {
|
||||
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
||||
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
||||
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
||||
GenKey(args []string) error
|
||||
}
|
||||
|
||||
type Tun interface {
|
||||
|
||||
5
go.mod
5
go.mod
@@ -3,17 +3,16 @@ module github.com/qdm12/gluetun
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/breml/rootcerts v0.2.17
|
||||
github.com/breml/rootcerts v0.2.16
|
||||
github.com/fatih/color v1.17.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/klauspost/compress v1.17.8
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2
|
||||
github.com/qdm12/dns v1.11.0
|
||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
||||
github.com/qdm12/gosettings v0.4.2
|
||||
github.com/qdm12/goshutdown v0.3.0
|
||||
github.com/qdm12/gosplash v0.2.0
|
||||
github.com/qdm12/gosplash v0.1.0
|
||||
github.com/qdm12/gotree v0.2.0
|
||||
github.com/qdm12/log v0.1.0
|
||||
github.com/qdm12/ss-server v0.6.0
|
||||
|
||||
18
go.sum
18
go.sum
@@ -4,8 +4,8 @@ github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7/go.mod h1:6zEj6s6u/g
|
||||
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
|
||||
github.com/breml/rootcerts v0.2.17 h1:0/M2BE2Apw0qEJCXDOkaiu7d5Sx5ObNfe1BkImJ4u1I=
|
||||
github.com/breml/rootcerts v0.2.17/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
|
||||
github.com/breml/rootcerts v0.2.16 h1:yN1TGvicfHx8dKz3OQRIrx/5nE/iN3XT1ibqGbd6urc=
|
||||
github.com/breml/rootcerts v0.2.16/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -83,8 +83,6 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -95,12 +93,14 @@ github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8
|
||||
github.com/qdm12/golibs v0.0.0-20210723175634-a75ca7fd74c2/go.mod h1:6aRbg4Z/bTbm9JfxsGXfWKHi7zsOvPfUTK1S5HuAFKg=
|
||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6 h1:bge5AL7cjHJMPz+5IOz5yF01q/l8No6+lIEBieA8gMg=
|
||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6/go.mod h1:6aRbg4Z/bTbm9JfxsGXfWKHi7zsOvPfUTK1S5HuAFKg=
|
||||
github.com/qdm12/gosettings v0.4.1 h1:c7+14jO1Y2kFXBCUfS2+QE2NgwTKfzcdJzGEFRItCI8=
|
||||
github.com/qdm12/gosettings v0.4.1/go.mod h1:uItKwGXibJp2pQ0am6MBKilpjfvYTGiH+zXHd10jFj8=
|
||||
github.com/qdm12/gosettings v0.4.2 h1:Gb39NScPr7OQV+oy0o1OD7A121udITDJuUGa7ljDF58=
|
||||
github.com/qdm12/gosettings v0.4.2/go.mod h1:CPrt2YC4UsURTrslmhxocVhMCW03lIrqdH2hzIf5prg=
|
||||
github.com/qdm12/goshutdown v0.3.0 h1:pqBpJkdwlZlfTEx4QHtS8u8CXx6pG0fVo6S1N0MpSEM=
|
||||
github.com/qdm12/goshutdown v0.3.0/go.mod h1:EqZ46No00kCTZ5qzdd3qIzY6ayhMt24QI8Mh8LVQYmM=
|
||||
github.com/qdm12/gosplash v0.2.0 h1:DOxCEizbW6ZG+FgpH2oK1atT6bM8MHL9GZ2ywSS4zZY=
|
||||
github.com/qdm12/gosplash v0.2.0/go.mod h1:k+1PzhO0th9cpX4q2Nneu4xTsndXqrM/x7NTIYmJ4jo=
|
||||
github.com/qdm12/gosplash v0.1.0 h1:Sfl+zIjFZFP7b0iqf2l5UkmEY97XBnaKkH3FNY6Gf7g=
|
||||
github.com/qdm12/gosplash v0.1.0/go.mod h1:+A3fWW4/rUeDXhY3ieBzwghKdnIPFJgD8K3qQkenJlw=
|
||||
github.com/qdm12/gotree v0.2.0 h1:+58ltxkNLUyHtATFereAcOjBVfY6ETqRex8XK90Fb/c=
|
||||
github.com/qdm12/gotree v0.2.0/go.mod h1:1SdFaqKZuI46U1apbXIf25pDMNnrPuYLEqMF/qL4lY4=
|
||||
github.com/qdm12/log v0.1.0 h1:jYBd/xscHYpblzZAd2kjZp2YmuYHjAAfbTViJWxoPTw=
|
||||
@@ -115,16 +115,10 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
|
||||
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *CLI) GenKey(args []string) (err error) {
|
||||
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
|
||||
err = flagSet.Parse(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing flags: %w", err)
|
||||
}
|
||||
|
||||
const keyLength = 128 / 8
|
||||
keyBytes := make([]byte, keyLength)
|
||||
|
||||
_, _ = rand.Read(keyBytes)
|
||||
|
||||
key := base58Encode(keyBytes)
|
||||
fmt.Println(key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func base58Encode(data []byte) string {
|
||||
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||
const radix = 58
|
||||
|
||||
zcount := 0
|
||||
for zcount < len(data) && data[zcount] == 0 {
|
||||
zcount++
|
||||
}
|
||||
|
||||
// integer simplification of ceil(log(256)/log(58))
|
||||
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
|
||||
size := zcount + ceilLog256Div58
|
||||
|
||||
output := make([]byte, size)
|
||||
|
||||
high := size - 1
|
||||
for _, b := range data {
|
||||
i := size - 1
|
||||
for carry := uint32(b); i > high || carry != 0; i-- {
|
||||
carry += 256 * uint32(output[i]) //nolint:gomnd
|
||||
output[i] = byte(carry % radix)
|
||||
carry /= radix
|
||||
}
|
||||
high = i
|
||||
}
|
||||
|
||||
// Determine the additional "zero-gap" in the output buffer
|
||||
additionalZeroGapEnd := zcount
|
||||
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
|
||||
additionalZeroGapEnd++
|
||||
}
|
||||
|
||||
val := output[additionalZeroGapEnd-zcount:]
|
||||
size = len(val)
|
||||
for i := range val {
|
||||
output[i] = alphabet[val[i]]
|
||||
}
|
||||
|
||||
return string(output[:size])
|
||||
}
|
||||
@@ -39,7 +39,6 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
|
||||
providers.Ivpn,
|
||||
providers.Mullvad,
|
||||
providers.Nordvpn,
|
||||
providers.Protonvpn,
|
||||
providers.Surfshark,
|
||||
providers.Windscribe,
|
||||
}
|
||||
|
||||
@@ -19,11 +19,6 @@ type ControlServer struct {
|
||||
// Log can be true or false to enable logging on requests.
|
||||
// It cannot be nil in the internal state.
|
||||
Log *bool
|
||||
// AuthFilePath is the path to the file containing the authentication
|
||||
// configuration for the middleware.
|
||||
// It cannot be empty in the internal state and defaults to
|
||||
// /gluetun/auth/config.toml.
|
||||
AuthFilePath string
|
||||
}
|
||||
|
||||
func (c ControlServer) validate() (err error) {
|
||||
@@ -49,9 +44,8 @@ func (c ControlServer) validate() (err error) {
|
||||
|
||||
func (c *ControlServer) copy() (copied ControlServer) {
|
||||
return ControlServer{
|
||||
Address: gosettings.CopyPointer(c.Address),
|
||||
Log: gosettings.CopyPointer(c.Log),
|
||||
AuthFilePath: c.AuthFilePath,
|
||||
Address: gosettings.CopyPointer(c.Address),
|
||||
Log: gosettings.CopyPointer(c.Log),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,13 +55,11 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
||||
func (c *ControlServer) overrideWith(other ControlServer) {
|
||||
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
||||
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
||||
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
|
||||
}
|
||||
|
||||
func (c *ControlServer) setDefaults() {
|
||||
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
||||
c.Log = gosettings.DefaultPointer(c.Log, true)
|
||||
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
|
||||
}
|
||||
|
||||
func (c ControlServer) String() string {
|
||||
@@ -78,7 +70,6 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Control server settings:")
|
||||
node.Appendf("Listening address: %s", *c.Address)
|
||||
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
||||
node.Appendf("Authentication file path: %s", c.AuthFilePath)
|
||||
return node
|
||||
}
|
||||
|
||||
@@ -87,10 +78,6 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
||||
|
||||
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -191,19 +191,11 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||
}
|
||||
|
||||
if vpnServiceProvider == providers.Custom {
|
||||
switch len(settings.Names) {
|
||||
case 0:
|
||||
case 1:
|
||||
// Allow a single name to be specified for the custom provider in case
|
||||
// the user wants to use VPN server side port forwarding with PIA
|
||||
// which requires a server name for TLS verification.
|
||||
filterChoices.Names = settings.Names
|
||||
default:
|
||||
return fmt.Errorf("%w: %d names specified instead of "+
|
||||
"0 or 1 for the custom provider",
|
||||
ErrNameNotValid, len(settings.Names))
|
||||
}
|
||||
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
|
||||
// Allow a single name to be specified for the custom provider in case
|
||||
// the user wants to use VPN server side port forwarding with PIA
|
||||
// which requires a server name for TLS verification.
|
||||
filterChoices.Names = settings.Names
|
||||
}
|
||||
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
||||
if err != nil {
|
||||
@@ -237,8 +229,6 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
|
||||
switch {
|
||||
case *settings.OwnedOnly && vpnServiceProvider != providers.Mullvad:
|
||||
return fmt.Errorf("%w", ErrOwnedOnlyNotSupported)
|
||||
case vpnServiceProvider == providers.Protonvpn && *settings.FreeOnly && *settings.PortForwardOnly:
|
||||
return fmt.Errorf("%w: together with free only filter", ErrPortForwardOnlyNotSupported)
|
||||
case *settings.StreamOnly &&
|
||||
!helpers.IsOneOf(vpnServiceProvider, providers.Protonvpn, providers.VPNUnlimited):
|
||||
return fmt.Errorf("%w", ErrStreamOnlyNotSupported)
|
||||
|
||||
@@ -78,8 +78,7 @@ func Test_Settings_String(t *testing.T) {
|
||||
| └── Enabled: no
|
||||
├── Control server settings:
|
||||
| ├── Listening address: :8000
|
||||
| ├── Logging: yes
|
||||
| └── Authentication file path: /gluetun/auth/config.toml
|
||||
| └── Logging: yes
|
||||
├── OS Alpine settings:
|
||||
| ├── Process UID: 1000
|
||||
| └── Process GID: 1000
|
||||
|
||||
@@ -62,7 +62,6 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
||||
providers.Ivpn,
|
||||
providers.Mullvad,
|
||||
providers.Nordvpn,
|
||||
providers.Protonvpn,
|
||||
providers.Surfshark,
|
||||
providers.Windscribe,
|
||||
) {
|
||||
@@ -174,15 +173,10 @@ func (w *Wireguard) overrideWith(other Wireguard) {
|
||||
func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
w.PrivateKey = gosettings.DefaultPointer(w.PrivateKey, "")
|
||||
w.PreSharedKey = gosettings.DefaultPointer(w.PreSharedKey, "")
|
||||
switch vpnProvider {
|
||||
case providers.Nordvpn:
|
||||
if vpnProvider == providers.Nordvpn {
|
||||
defaultNordVPNAddress := netip.AddrFrom4([4]byte{10, 5, 0, 2})
|
||||
defaultNordVPNPrefix := netip.PrefixFrom(defaultNordVPNAddress, defaultNordVPNAddress.BitLen())
|
||||
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultNordVPNPrefix})
|
||||
case providers.Protonvpn:
|
||||
defaultAddress := netip.AddrFrom4([4]byte{10, 2, 0, 2})
|
||||
defaultPrefix := netip.PrefixFrom(defaultAddress, defaultAddress.BitLen())
|
||||
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultPrefix})
|
||||
}
|
||||
defaultAllowedIPs := []netip.Prefix{
|
||||
netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||
|
||||
@@ -39,8 +39,8 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
// Validate EndpointIP
|
||||
switch vpnProvider {
|
||||
case providers.Airvpn, providers.Fastestvpn, providers.Ivpn,
|
||||
providers.Mullvad, providers.Nordvpn, providers.Protonvpn,
|
||||
providers.Surfshark, providers.Windscribe:
|
||||
providers.Mullvad, providers.Nordvpn, providers.Surfshark,
|
||||
providers.Windscribe:
|
||||
// endpoint IP addresses are baked in
|
||||
case providers.Custom:
|
||||
if !w.EndpointIP.IsValid() || w.EndpointIP.IsUnspecified() {
|
||||
@@ -57,8 +57,7 @@ func (w WireguardSelection) validate(vpnProvider string) (err error) {
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointPortNotSet)
|
||||
}
|
||||
// EndpointPort cannot be set
|
||||
case providers.Fastestvpn, providers.Nordvpn,
|
||||
providers.Protonvpn, providers.Surfshark:
|
||||
case providers.Fastestvpn, providers.Surfshark, providers.Nordvpn:
|
||||
if *w.EndpointPort != 0 {
|
||||
return fmt.Errorf("%w", ErrWireguardEndpointPortSet)
|
||||
}
|
||||
|
||||
@@ -3,4 +3,13 @@ package vpn
|
||||
const (
|
||||
OpenVPN = "openvpn"
|
||||
Wireguard = "wireguard"
|
||||
Both = "openvpn+wireguard"
|
||||
)
|
||||
|
||||
func IsWireguard(s string) bool {
|
||||
return s == Wireguard || s == Both
|
||||
}
|
||||
|
||||
func IsOpenVPN(s string) bool {
|
||||
return s == OpenVPN || s == Both
|
||||
}
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isDeleteMatchInstruction returns true if the iptables instruction
|
||||
// is a delete instruction by rule matching. It returns false if the
|
||||
// instruction is a delete instruction by line number, or not a delete
|
||||
// instruction.
|
||||
func isDeleteMatchInstruction(instruction string) bool {
|
||||
fields := strings.Fields(instruction)
|
||||
for i, field := range fields {
|
||||
switch {
|
||||
case field != "-D" && field != "--delete": //nolint:goconst
|
||||
continue
|
||||
case i == len(fields)-1: // malformed: missing chain name
|
||||
return false
|
||||
case i == len(fields)-2: // chain name is last field
|
||||
return true
|
||||
default:
|
||||
// chain name is fields[i+1]
|
||||
const base, bitLength = 10, 16
|
||||
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
|
||||
return err != nil // not a line number
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
|
||||
runner Runner, logger Logger) (err error) {
|
||||
targetRule, err := parseIptablesInstruction(instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, iptablesBinary,
|
||||
targetRule, runner, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding iptables chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
logger.Debug("rule matching \"" + instruction + "\" not found")
|
||||
return nil
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
|
||||
instruction, lineNumber))
|
||||
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
|
||||
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findLineNumber finds the line number of an iptables rule.
|
||||
// It returns 0 if the rule is not found.
|
||||
func findLineNumber(ctx context.Context, iptablesBinary string,
|
||||
instruction iptablesInstruction, runner Runner, logger Logger) (
|
||||
lineNumber uint16, err error) {
|
||||
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
|
||||
"--line-numbers", "-n", "-v"}
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
chain, err := parseChain(output)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing chain list: %w", err)
|
||||
}
|
||||
|
||||
for _, rule := range chain.rules {
|
||||
if instruction.equalToRule(instruction.table, chain.name, rule) {
|
||||
return rule.lineNumber, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
@@ -1,188 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_isDeleteMatchInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
isDeleteMatch bool
|
||||
}{
|
||||
"not_delete": {
|
||||
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
|
||||
},
|
||||
"malformed_missing_chain_name": {
|
||||
instruction: "-t nat -D",
|
||||
},
|
||||
"delete_chain_name_last_field": {
|
||||
instruction: "-t nat --delete PREROUTING",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_match": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_line_number_last_field": {
|
||||
instruction: "-t nat -D PREROUTING 2",
|
||||
},
|
||||
"delete_line_number": {
|
||||
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
|
||||
},
|
||||
}
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
|
||||
|
||||
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
|
||||
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
|
||||
"^--line-numbers$", "^-n$", "^-v$")
|
||||
}
|
||||
|
||||
func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
"rule_not_found": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
|
||||
nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
"rule_found_delete_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
instruction := testCase.instruction
|
||||
var runner *MockRunner
|
||||
if testCase.makeRunner != nil {
|
||||
runner = testCase.makeRunner(ctrl)
|
||||
}
|
||||
var logger *MockLogger
|
||||
if testCase.makeLogger != nil {
|
||||
logger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import "github.com/qdm12/golibs/command"
|
||||
|
||||
type Runner interface {
|
||||
Run(cmd command.ExecCmd) (output string, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// empty string path is returned.
|
||||
func findIP6tablesSupported(ctx context.Context, runner command.Runner) (
|
||||
ip6tablesPath string, err error) {
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables", "ip6tables-nft", "ip6tables-legacy")
|
||||
ip6tablesPath, err = checkIptablesSupport(ctx, runner, "ip6tables-legacy", "ip6tables", "ip6tables-nft")
|
||||
if errors.Is(err, ErrIPTablesNotSupported) {
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
@@ -40,14 +40,10 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
c.logger.Debug(c.ip6Tables + " " + instruction)
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ip6Tables, instruction, output, err)
|
||||
@@ -59,7 +55,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP": //nolint:goconst
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
|
||||
@@ -70,14 +70,10 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
c.logger.Debug(c.ipTables + " " + instruction)
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ipTables, instruction, output, err)
|
||||
@@ -147,7 +143,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool) error {
|
||||
protocol := connection.Protocol
|
||||
if protocol == "tcp-client" {
|
||||
protocol = "tcp" //nolint:goconst
|
||||
protocol = "tcp"
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||
@@ -214,14 +210,10 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
|
||||
}
|
||||
|
||||
err = c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
|
||||
@@ -229,14 +221,10 @@ func (c *Config) redirectPort(ctx context.Context, intf string,
|
||||
}
|
||||
|
||||
err = c.runIP6tablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p tcp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p tcp -m tcp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p udp --dport %d -j REDIRECT --to-ports %d",
|
||||
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
|
||||
fmt.Sprintf("%s INPUT %s -p udp -m udp --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), interfaceFlag, destinationPort),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
|
||||
|
||||
@@ -1,381 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type chain struct {
|
||||
name string
|
||||
policy string
|
||||
packets uint64
|
||||
bytes uint64
|
||||
rules []chainRule
|
||||
}
|
||||
|
||||
type chainRule struct {
|
||||
lineNumber uint16 // starts from 1 and cannot be zero.
|
||||
packets uint64
|
||||
bytes uint64
|
||||
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
||||
protocol string // "tcp", "udp" or "" for all protocols.
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
)
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// pkts bytes target prot opt in out source destination
|
||||
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
||||
linesWithComments := strings.Split(iptablesOutput, "\n")
|
||||
|
||||
// Filter out lines starting with a '#' character
|
||||
lines := make([]string, 0, len(linesWithComments))
|
||||
for _, line := range linesWithComments {
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check for the legend line
|
||||
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
if len(lines) == 0 {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
c.rules = make([]chainRule, len(lines))
|
||||
for i, line := range lines {
|
||||
c.rules[i], err = parseChainRuleLine(line)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// parseChainGeneralDataLine parses the first line of iptables chain list output.
|
||||
// For example, it can parse the following line:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// It returns a chain struct with the parsed data.
|
||||
func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
runesToRemove := []rune{'(', ')', ','}
|
||||
for _, r := range runesToRemove {
|
||||
line = strings.ReplaceAll(line, string(r), "")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
indexToExpectedValue := map[int]string{
|
||||
0: "Chain",
|
||||
2: "policy",
|
||||
5: "packets",
|
||||
7: "bytes",
|
||||
}
|
||||
for index, expectedValue := range indexToExpectedValue {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
base.policy = fields[3]
|
||||
err = checkTarget(base.policy)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
|
||||
}
|
||||
|
||||
packets, err := parseMetricSize(fields[4])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
base.packets = packets
|
||||
|
||||
bytes, err := parseMetricSize(fields[6])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
base.bytes = bytes
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
)
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
err = parseChainRuleField(fieldIndex, field, &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) > minFields {
|
||||
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
numIndex = iota
|
||||
packetsIndex
|
||||
bytesIndex
|
||||
targetIndex
|
||||
protocolIndex
|
||||
optIndex
|
||||
inputInterfaceIndex
|
||||
outputInterfaceIndex
|
||||
sourceIndex
|
||||
destinationIndex
|
||||
)
|
||||
|
||||
switch fieldIndex {
|
||||
case numIndex:
|
||||
rule.lineNumber, err = parseLineNumber(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing line number: %w", err)
|
||||
}
|
||||
case packetsIndex:
|
||||
rule.packets, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
case bytesIndex:
|
||||
rule.bytes, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
case targetIndex:
|
||||
err = checkTarget(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking target: %w", err)
|
||||
}
|
||||
rule.target = field
|
||||
case protocolIndex:
|
||||
rule.protocol, err = parseProtocol(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing protocol: %w", err)
|
||||
}
|
||||
case optIndex: // ignored
|
||||
case inputInterfaceIndex:
|
||||
rule.inputInterface = field
|
||||
case outputInterfaceIndex:
|
||||
rule.outputInterface = field
|
||||
case sourceIndex:
|
||||
rule.source, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case destinationIndex:
|
||||
rule.destination, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i++
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
}
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
case "ports":
|
||||
i++
|
||||
ports, err := parsePortsCSV(optionalFields[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
)
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTargetUnknown = errors.New("unknown target")
|
||||
)
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
)
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0":
|
||||
case "6":
|
||||
protocol = "tcp"
|
||||
case "17":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
)
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
}
|
||||
|
||||
//nolint:gomnd
|
||||
multiplerLetterToValue := map[byte]uint64{
|
||||
'K': 1000,
|
||||
'M': 1000000,
|
||||
'G': 1000000000,
|
||||
'T': 1000000000000,
|
||||
}
|
||||
|
||||
lastCharacter := size[len(size)-1]
|
||||
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
||||
if ok { // multiplier present
|
||||
size = size[:len(size)-1]
|
||||
} else {
|
||||
multiplier = 1
|
||||
}
|
||||
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
},
|
||||
"no_rule": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
},
|
||||
},
|
||||
"some_rules": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
|
||||
`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
rules: []chainRule{
|
||||
{
|
||||
lineNumber: 1,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "udp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 2,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "tcp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 3,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "DROP",
|
||||
protocol: "",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,12 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
|
||||
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
|
||||
"no default route matching its family", subnet))
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
package firewall
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger
|
||||
@@ -1,109 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var (
|
||||
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
)
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match": // ignore match
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
portStrings := strings.Split(value, ",")
|
||||
instruction.toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseIptablesInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("5.6.7.8/32"),
|
||||
destinationPort: 10000,
|
||||
target: "ACCEPT",
|
||||
},
|
||||
},
|
||||
"nat_redirection": {
|
||||
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
target: "REDIRECT",
|
||||
toPorts: []uint16{5678},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseIPPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
value string
|
||||
prefix netip.Prefix
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
|
||||
},
|
||||
"invalid": {
|
||||
value: "invalid",
|
||||
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
|
||||
},
|
||||
"valid_ipv4_with_bits": {
|
||||
value: "10.0.0.0/16",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
|
||||
},
|
||||
"valid_ipv4_without_bits": {
|
||||
value: "10.0.0.4",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
|
||||
},
|
||||
"valid_ipv6_with_bits": {
|
||||
value: "2001:db8::/32",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
32),
|
||||
},
|
||||
"valid_ipv6_without_bits": {
|
||||
value: "2001:db8::",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
128),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prefix, err := parseIPPrefix(testCase.value)
|
||||
|
||||
assert.Equal(t, testCase.prefix, prefix)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
50
internal/firewall/runner_mock_test.go
Normal file
50
internal/firewall/runner_mock_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
return newCmdMatcher(path,
|
||||
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
|
||||
@@ -21,7 +21,7 @@ type Connection struct {
|
||||
PubKey string `json:"pubkey"`
|
||||
// ServerName is used for PIA for port forwarding
|
||||
ServerName string `json:"server_name,omitempty"`
|
||||
// PortForward is used for PIA and ProtonVPN for port forwarding
|
||||
// PortForward is used for PIA for port forwarding
|
||||
PortForward bool `json:"port_forward"`
|
||||
}
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ func getMarkdownHeaders(vpnProvider string) (headers []string) {
|
||||
case providers.Mullvad:
|
||||
return []string{countryHeader, cityHeader, ispHeader, ownedHeader, hostnameHeader, vpnHeader}
|
||||
case providers.Nordvpn:
|
||||
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, vpnHeader, categoriesHeader}
|
||||
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, categoriesHeader}
|
||||
case providers.Perfectprivacy:
|
||||
return []string{cityHeader, tcpHeader, udpHeader}
|
||||
case providers.Privado:
|
||||
@@ -135,15 +135,14 @@ func getMarkdownHeaders(vpnProvider string) (headers []string) {
|
||||
case providers.Privatevpn:
|
||||
return []string{countryHeader, cityHeader, hostnameHeader}
|
||||
case providers.Protonvpn:
|
||||
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, vpnHeader,
|
||||
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader,
|
||||
freeHeader, portForwardHeader, secureHeader, torHeader}
|
||||
case providers.Purevpn:
|
||||
return []string{countryHeader, regionHeader, cityHeader, hostnameHeader, tcpHeader, udpHeader}
|
||||
case providers.SlickVPN:
|
||||
return []string{regionHeader, countryHeader, cityHeader, hostnameHeader}
|
||||
case providers.Surfshark:
|
||||
return []string{regionHeader, countryHeader, cityHeader, hostnameHeader,
|
||||
vpnHeader, multiHopHeader, tcpHeader, udpHeader}
|
||||
return []string{regionHeader, countryHeader, cityHeader, hostnameHeader, multiHopHeader, tcpHeader, udpHeader}
|
||||
case providers.Torguard:
|
||||
return []string{countryHeader, cityHeader, hostnameHeader, tcpHeader, udpHeader}
|
||||
case providers.VPNSecure:
|
||||
|
||||
@@ -22,8 +22,8 @@ type Server struct {
|
||||
Number uint16 `json:"number,omitempty"`
|
||||
ServerName string `json:"server_name,omitempty"`
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
TCP bool `json:"tcp,omitempty"`
|
||||
UDP bool `json:"udp,omitempty"`
|
||||
TCP bool `json:"tcp,omitempty"` // TODO v4 rename to openvpn_tcp
|
||||
UDP bool `json:"udp,omitempty"` // TODO v4 rename to openvpn_udp
|
||||
OvpnX509 string `json:"x509,omitempty"`
|
||||
RetroLoc string `json:"retroloc,omitempty"` // TODO remove in v4
|
||||
MultiHop bool `json:"multihop,omitempty"`
|
||||
@@ -38,6 +38,25 @@ type Server struct {
|
||||
IPs []netip.Addr `json:"ips,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) SetVPN(vpnType string) {
|
||||
switch s.VPN {
|
||||
case "":
|
||||
s.VPN = vpnType
|
||||
case vpn.Both:
|
||||
return
|
||||
case vpn.OpenVPN:
|
||||
if vpnType == vpn.Wireguard {
|
||||
s.VPN = vpn.Both
|
||||
}
|
||||
case vpn.Wireguard:
|
||||
if vpnType == vpn.OpenVPN {
|
||||
s.VPN = vpn.Both
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("VPN type %q not supported", s.VPN))
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVPNFieldEmpty = errors.New("vpn field is empty")
|
||||
ErrHostnameFieldEmpty = errors.New("hostname field is empty")
|
||||
@@ -48,16 +67,18 @@ var (
|
||||
)
|
||||
|
||||
func (s *Server) HasMinimumInformation() (err error) {
|
||||
isOpenVPN := s.VPN == vpn.OpenVPN || s.VPN == vpn.Both
|
||||
isWireguard := s.VPN == vpn.Wireguard || s.VPN == vpn.Both
|
||||
switch {
|
||||
case s.VPN == "":
|
||||
return fmt.Errorf("%w", ErrVPNFieldEmpty)
|
||||
case len(s.IPs) == 0:
|
||||
return fmt.Errorf("%w", ErrIPsFieldEmpty)
|
||||
case s.VPN == vpn.Wireguard && (s.TCP || s.UDP):
|
||||
case isWireguard && !isOpenVPN && (s.TCP || s.UDP):
|
||||
return fmt.Errorf("%w", ErrNetworkProtocolSet)
|
||||
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
|
||||
case isOpenVPN && !s.TCP && !s.UDP:
|
||||
return fmt.Errorf("%w", ErrNoNetworkProtocol)
|
||||
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
|
||||
case isWireguard && s.WgPubKey == "":
|
||||
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -18,11 +18,6 @@ var (
|
||||
func extractDataFromLines(lines []string) (
|
||||
connection models.Connection, err error) {
|
||||
for i, line := range lines {
|
||||
hashSymbolIndex := strings.Index(line, "#")
|
||||
if hashSymbolIndex >= 0 {
|
||||
line = line[:hashSymbolIndex]
|
||||
}
|
||||
|
||||
ip, port, protocol, err := extractDataFromLine(line)
|
||||
if err != nil {
|
||||
return connection, fmt.Errorf("on line %d: %w", i+1, err)
|
||||
|
||||
@@ -40,12 +40,11 @@ func getOpenVPNConnection(extractor Extractor,
|
||||
connection.Port = customPort
|
||||
}
|
||||
|
||||
// assume all custom provider servers support port forwarding
|
||||
connection.PortForward = true
|
||||
if len(selection.Names) > 0 {
|
||||
// Set the server name for PIA port forwarding code used
|
||||
// together with the custom provider.
|
||||
connection.ServerName = selection.Names[0]
|
||||
connection.PortForward = true
|
||||
}
|
||||
|
||||
return connection, nil
|
||||
@@ -54,17 +53,17 @@ func getOpenVPNConnection(extractor Extractor,
|
||||
func getWireguardConnection(selection settings.ServerSelection) (
|
||||
connection models.Connection) {
|
||||
connection = models.Connection{
|
||||
Type: vpn.Wireguard,
|
||||
IP: selection.Wireguard.EndpointIP,
|
||||
Port: *selection.Wireguard.EndpointPort,
|
||||
Protocol: constants.UDP,
|
||||
PubKey: selection.Wireguard.PublicKey,
|
||||
PortForward: true, // assume all custom provider servers support port forwarding
|
||||
Type: vpn.Wireguard,
|
||||
IP: selection.Wireguard.EndpointIP,
|
||||
Port: *selection.Wireguard.EndpointPort,
|
||||
Protocol: constants.UDP,
|
||||
PubKey: selection.Wireguard.PublicKey,
|
||||
}
|
||||
if len(selection.Names) > 0 {
|
||||
// Set the server name for PIA port forwarding code used
|
||||
// together with the custom provider.
|
||||
connection.ServerName = selection.Names[0]
|
||||
connection.PortForward = true
|
||||
}
|
||||
return connection
|
||||
}
|
||||
|
||||
@@ -68,26 +68,17 @@ func (hts hostToServerData) adaptWithIPs(hostToIPs map[string][]netip.Addr) {
|
||||
func (hts hostToServerData) toServersSlice() (servers []models.Server) {
|
||||
servers = make([]models.Server, 0, 2*len(hts)) //nolint:gomnd
|
||||
for hostname, serverData := range hts {
|
||||
baseServer := models.Server{
|
||||
server := models.Server{
|
||||
VPN: vpn.Both,
|
||||
Hostname: hostname,
|
||||
Country: serverData.country,
|
||||
City: serverData.city,
|
||||
IPs: serverData.ips,
|
||||
TCP: serverData.openvpnTCP,
|
||||
UDP: serverData.openvpnUDP,
|
||||
WgPubKey: "658QxufMbjOTmB61Z7f+c7Rjg7oqWLnepTalqBERjF0=",
|
||||
}
|
||||
if serverData.openvpn {
|
||||
openvpnServer := baseServer
|
||||
openvpnServer.VPN = vpn.OpenVPN
|
||||
openvpnServer.TCP = serverData.openvpnTCP
|
||||
openvpnServer.UDP = serverData.openvpnUDP
|
||||
servers = append(servers, openvpnServer)
|
||||
}
|
||||
if serverData.wireguard {
|
||||
wireguardServer := baseServer
|
||||
wireguardServer.VPN = vpn.Wireguard
|
||||
const wireguardPublicKey = "658QxufMbjOTmB61Z7f+c7Rjg7oqWLnepTalqBERjF0="
|
||||
wireguardServer.WgPubKey = wireguardPublicKey
|
||||
servers = append(servers, wireguardServer)
|
||||
}
|
||||
servers = append(servers, server)
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -59,11 +58,9 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
|
||||
servers = make([]models.Server, 0, len(hostToIPs))
|
||||
for _, serverData := range data.Servers {
|
||||
city, region := parseCity(serverData.City)
|
||||
server := models.Server{
|
||||
Country: serverData.Country,
|
||||
City: city,
|
||||
Region: region,
|
||||
City: serverData.City,
|
||||
ISP: serverData.ISP,
|
||||
}
|
||||
|
||||
@@ -99,11 +96,3 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func parseCity(city string) (parsedCity, region string) {
|
||||
commaIndex := strings.Index(city, ", ")
|
||||
if commaIndex == -1 {
|
||||
return city, ""
|
||||
}
|
||||
return city[:commaIndex], city[commaIndex+2:]
|
||||
}
|
||||
|
||||
@@ -35,11 +35,11 @@ func (hts hostToServer) add(data serverData) (err error) {
|
||||
|
||||
switch data.Type {
|
||||
case "openvpn":
|
||||
server.VPN = vpn.OpenVPN
|
||||
server.SetVPN(vpn.OpenVPN)
|
||||
server.UDP = true
|
||||
server.TCP = true
|
||||
case "wireguard":
|
||||
server.VPN = vpn.Wireguard
|
||||
server.SetVPN(vpn.Wireguard)
|
||||
case "bridge":
|
||||
// ignore bridge servers
|
||||
return nil
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Check out the JSON data from https://api.nordvpn.com/v2/servers?limit=10
|
||||
@@ -93,9 +92,6 @@ func (s serversData) idToData() (
|
||||
) {
|
||||
groups = make(map[uint32]groupData, len(s.Groups))
|
||||
for _, group := range s.Groups {
|
||||
if group.Type.Identifier == "regions" { //nolint:goconst
|
||||
group.Title = strings.ReplaceAll(group.Title, ",", "")
|
||||
}
|
||||
groups[group.ID] = group
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
|
||||
|
||||
server := models.Server{
|
||||
Country: location.Country.Name,
|
||||
Region: region,
|
||||
Region: jsonServer.region(groups),
|
||||
City: location.Country.City.Name,
|
||||
Categories: jsonServer.categories(groups),
|
||||
Hostname: jsonServer.Hostname,
|
||||
@@ -98,12 +98,6 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
|
||||
server.Number = number
|
||||
}
|
||||
|
||||
var wireguardFound, openvpnFound bool
|
||||
wireguardServer := server
|
||||
wireguardServer.VPN = vpn.Wireguard
|
||||
openVPNServer := server // accumulate UDP+TCP technologies
|
||||
openVPNServer.VPN = vpn.OpenVPN
|
||||
|
||||
for _, technology := range jsonServer.Technologies {
|
||||
if technology.Status != "online" {
|
||||
continue
|
||||
@@ -118,33 +112,25 @@ func extractServers(jsonServer serverData, groups map[uint32]groupData,
|
||||
|
||||
switch technologyData.Identifier {
|
||||
case "openvpn_udp", "openvpn_dedicated_udp":
|
||||
openvpnFound = true
|
||||
openVPNServer.UDP = true
|
||||
server.SetVPN(vpn.OpenVPN)
|
||||
server.UDP = true
|
||||
case "openvpn_tcp", "openvpn_dedicated_tcp":
|
||||
openvpnFound = true
|
||||
openVPNServer.TCP = true
|
||||
server.SetVPN(vpn.OpenVPN)
|
||||
server.TCP = true
|
||||
case "wireguard_udp":
|
||||
wireguardFound = true
|
||||
wireguardServer.WgPubKey, err = jsonServer.wireguardPublicKey(technologies)
|
||||
server.WgPubKey, err = jsonServer.wireguardPublicKey(technologies)
|
||||
if err != nil {
|
||||
warning := fmt.Sprintf("ignoring Wireguard server %s: %s", jsonServer.Name, err)
|
||||
warnings = append(warnings, warning)
|
||||
wireguardFound = false
|
||||
continue
|
||||
}
|
||||
server.SetVPN(vpn.Wireguard)
|
||||
default: // Ignore other technologies
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
const maxServers = 2
|
||||
servers = make([]models.Server, 0, maxServers)
|
||||
if openvpnFound {
|
||||
servers = append(servers, openVPNServer)
|
||||
}
|
||||
if wireguardFound {
|
||||
servers = append(servers, wireguardServer)
|
||||
}
|
||||
servers = append(servers, server)
|
||||
|
||||
return servers, warnings
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
}
|
||||
|
||||
serverName := objects.ServerName
|
||||
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||
|
||||
logger := objects.Logger
|
||||
|
||||
if !objects.CanPortForward {
|
||||
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
|
||||
if !dataFound || expired {
|
||||
client := objects.Client
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
|
||||
p.portForwardPath, objects.Username, objects.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
||||
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
|
||||
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
|
||||
return nil, fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
|
||||
@@ -100,8 +100,6 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
panic("gateway is not set")
|
||||
}
|
||||
|
||||
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||
|
||||
privateIPClient, err := newHTTPClient(objects.ServerName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating custom HTTP client: %w", err)
|
||||
@@ -129,7 +127,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-keepAliveTimer.C:
|
||||
err = bindPort(ctx, privateIPClient, apiIP, data)
|
||||
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
@@ -141,25 +139,14 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
|
||||
if gateway.Is6() {
|
||||
panic("IPv6 gateway not supported")
|
||||
}
|
||||
|
||||
gatewayBytes := gateway.As4()
|
||||
gatewayBytes[2] = 128
|
||||
gatewayBytes[3] = 1
|
||||
return netip.AddrFrom4(gatewayBytes)
|
||||
}
|
||||
|
||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
||||
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, client, username, password)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("fetching token: %w", err)
|
||||
}
|
||||
|
||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
|
||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("fetching port forwarding data: %w", err)
|
||||
}
|
||||
@@ -299,7 +286,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
|
||||
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
|
||||
port uint16, signature string, expiration time.Time, err error) {
|
||||
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
|
||||
|
||||
@@ -307,7 +294,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.
|
||||
queryParams.Add("token", token)
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(apiIP.String(), "19999"),
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Path: "/getSignature",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
@@ -353,7 +340,7 @@ var (
|
||||
ErrBadResponse = errors.New("bad response received")
|
||||
)
|
||||
|
||||
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
|
||||
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
|
||||
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serializing payload: %w", err)
|
||||
@@ -364,7 +351,7 @@ func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr,
|
||||
queryParams.Add("signature", data.Signature)
|
||||
bindPortURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Path: "/bindPort",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/updater/openvpn"
|
||||
@@ -65,7 +64,6 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
continue
|
||||
}
|
||||
server := models.Server{
|
||||
VPN: vpn.OpenVPN,
|
||||
Country: country,
|
||||
City: city,
|
||||
IPs: ips,
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
func (p *Provider) GetConnection(selection settings.ServerSelection, ipv6Supported bool) (
|
||||
connection models.Connection, err error) {
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 51820) //nolint:gomnd
|
||||
defaults := utils.NewConnectionDefaults(443, 1194, 0) //nolint:gomnd
|
||||
return utils.GetConnection(p.Name(),
|
||||
p.storage, selection, defaults, ipv6Supported, p.randSource)
|
||||
}
|
||||
|
||||
@@ -28,11 +28,10 @@ type logicalServer struct {
|
||||
}
|
||||
|
||||
type physicalServer struct {
|
||||
EntryIP netip.Addr `json:"EntryIP"`
|
||||
ExitIP netip.Addr `json:"ExitIP"`
|
||||
Domain string `json:"Domain"`
|
||||
Status uint8 `json:"Status"`
|
||||
X25519PublicKey string `json:"X25519PublicKey"`
|
||||
EntryIP netip.Addr `json:"EntryIP"`
|
||||
ExitIP netip.Addr `json:"ExitIP"`
|
||||
Domain string `json:"Domain"`
|
||||
Status uint8 `json:"Status"`
|
||||
}
|
||||
|
||||
func fetchAPI(ctx context.Context, client *http.Client) (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type ipToServers map[string][2]models.Server // first server is OpenVPN, second is Wireguard.
|
||||
type ipToServer map[string]models.Server
|
||||
|
||||
type features struct {
|
||||
secureCore bool
|
||||
@@ -16,50 +16,36 @@ type features struct {
|
||||
stream bool
|
||||
}
|
||||
|
||||
func (its ipToServers) add(country, region, city, name, hostname, wgPubKey string,
|
||||
func (its ipToServer) add(country, region, city, name, hostname string,
|
||||
free bool, entryIP netip.Addr, features features) {
|
||||
key := entryIP.String()
|
||||
|
||||
servers, ok := its[key]
|
||||
server, ok := its[key]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
baseServer := models.Server{
|
||||
Country: country,
|
||||
Region: region,
|
||||
City: city,
|
||||
ServerName: name,
|
||||
Hostname: hostname,
|
||||
Free: free,
|
||||
SecureCore: features.secureCore,
|
||||
Tor: features.tor,
|
||||
PortForward: features.p2p,
|
||||
Stream: features.stream,
|
||||
IPs: []netip.Addr{entryIP},
|
||||
}
|
||||
openvpnServer := baseServer
|
||||
openvpnServer.VPN = vpn.OpenVPN
|
||||
openvpnServer.UDP = true
|
||||
openvpnServer.TCP = true
|
||||
servers[0] = openvpnServer
|
||||
wireguardServer := baseServer
|
||||
wireguardServer.VPN = vpn.Wireguard
|
||||
wireguardServer.WgPubKey = wgPubKey
|
||||
servers[1] = wireguardServer
|
||||
its[key] = servers
|
||||
server.VPN = vpn.OpenVPN
|
||||
server.Country = country
|
||||
server.Region = region
|
||||
server.City = city
|
||||
server.ServerName = name
|
||||
server.Hostname = hostname
|
||||
server.Free = free
|
||||
server.SecureCore = features.secureCore
|
||||
server.Tor = features.tor
|
||||
server.PortForward = features.p2p
|
||||
server.Stream = features.stream
|
||||
server.UDP = true
|
||||
server.TCP = true
|
||||
server.IPs = []netip.Addr{entryIP}
|
||||
its[key] = server
|
||||
}
|
||||
|
||||
func (its ipToServers) toServersSlice() (serversSlice []models.Server) {
|
||||
const vpnProtocols = 2
|
||||
serversSlice = make([]models.Server, 0, vpnProtocols*len(its))
|
||||
for _, servers := range its {
|
||||
serversSlice = append(serversSlice, servers[0], servers[1])
|
||||
func (its ipToServer) toServersSlice() (servers []models.Server) {
|
||||
servers = make([]models.Server, 0, len(its))
|
||||
for _, server := range its {
|
||||
servers = append(servers, server)
|
||||
}
|
||||
return serversSlice
|
||||
}
|
||||
|
||||
func (its ipToServers) numberOfServers() int {
|
||||
const serversPerIP = 2
|
||||
return len(its) * serversPerIP
|
||||
return servers
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
common.ErrNotEnoughServers, count, minServers)
|
||||
}
|
||||
|
||||
ipToServer := make(ipToServers, count)
|
||||
ipToServer := make(ipToServer, count)
|
||||
for _, logicalServer := range data.LogicalServers {
|
||||
region := getStringValue(logicalServer.Region)
|
||||
city := getStringValue(logicalServer.City)
|
||||
@@ -65,7 +65,6 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
|
||||
hostname := physicalServer.Domain
|
||||
entryIP := physicalServer.EntryIP
|
||||
wgPubKey := physicalServer.X25519PublicKey
|
||||
|
||||
// Note: for multi-hop use the server name or hostname
|
||||
// instead of the country
|
||||
@@ -75,11 +74,11 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
u.warner.Warn(warning)
|
||||
}
|
||||
|
||||
ipToServer.add(country, region, city, name, hostname, wgPubKey, free, entryIP, features)
|
||||
ipToServer.add(country, region, city, name, hostname, free, entryIP, features)
|
||||
}
|
||||
}
|
||||
|
||||
if ipToServer.numberOfServers() < minServers {
|
||||
if len(ipToServer) < minServers {
|
||||
return nil, fmt.Errorf("%w: %d and expected at least %d",
|
||||
common.ErrNotEnoughServers, len(ipToServer), minServers)
|
||||
}
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/provider/surfshark/servers"
|
||||
)
|
||||
|
||||
func addServersFromAPI(ctx context.Context, client *http.Client,
|
||||
hts hostToServers) (err error) {
|
||||
hts hostToServer) (err error) {
|
||||
data, err := fetchAPI(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -25,12 +26,14 @@ func addServersFromAPI(ctx context.Context, client *http.Client,
|
||||
retroLoc := locationData.RetroLoc // empty string if the host has no retro-compatible region
|
||||
|
||||
tcp, udp := true, true // OpenVPN servers from API supports both TCP and UDP
|
||||
hts.addOpenVPN(serverData.Host, serverData.Region, serverData.Country,
|
||||
serverData.Location, retroLoc, tcp, udp)
|
||||
const wgPubKey = ""
|
||||
hts.add(serverData.Host, vpn.OpenVPN, serverData.Region, serverData.Country,
|
||||
serverData.Location, retroLoc, wgPubKey, tcp, udp)
|
||||
|
||||
if serverData.PubKey != "" {
|
||||
hts.addWireguard(serverData.Host, serverData.Region, serverData.Country,
|
||||
serverData.Location, retroLoc, serverData.PubKey)
|
||||
const wgTCP, wgUDP = false, false // unused
|
||||
hts.add(serverData.Host, vpn.Wireguard, serverData.Region, serverData.Country,
|
||||
serverData.Location, retroLoc, serverData.PubKey, wgTCP, wgUDP)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ func Test_addServersFromAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
hts hostToServers
|
||||
hts hostToServer
|
||||
exchanges []httpExchange
|
||||
expected hostToServers
|
||||
expected hostToServer
|
||||
err error
|
||||
}{
|
||||
"fetch API error": {
|
||||
@@ -37,8 +37,8 @@ func Test_addServersFromAPI(t *testing.T) {
|
||||
err: errors.New("HTTP status code not OK: 204 No Content"),
|
||||
},
|
||||
"success": {
|
||||
hts: hostToServers{
|
||||
"existinghost": []models.Server{{Hostname: "existinghost"}},
|
||||
hts: hostToServer{
|
||||
"existinghost": models.Server{Hostname: "existinghost"},
|
||||
},
|
||||
exchanges: []httpExchange{{
|
||||
requestURL: "https://api.surfshark.com/v4/server/clusters/generic",
|
||||
@@ -61,25 +61,19 @@ func Test_addServersFromAPI(t *testing.T) {
|
||||
responseStatus: http.StatusOK,
|
||||
responseBody: io.NopCloser(strings.NewReader(`[]`)),
|
||||
}},
|
||||
expected: map[string][]models.Server{
|
||||
"existinghost": {{Hostname: "existinghost"}},
|
||||
"host1": {{
|
||||
VPN: vpn.OpenVPN,
|
||||
expected: map[string]models.Server{
|
||||
"existinghost": {Hostname: "existinghost"},
|
||||
"host1": {
|
||||
VPN: vpn.Both,
|
||||
Region: "region1",
|
||||
Country: "country1",
|
||||
City: "location1",
|
||||
Hostname: "host1",
|
||||
TCP: true,
|
||||
UDP: true,
|
||||
}, {
|
||||
VPN: vpn.Wireguard,
|
||||
Region: "region1",
|
||||
Country: "country1",
|
||||
City: "location1",
|
||||
Hostname: "host1",
|
||||
WgPubKey: "pubKeyValue",
|
||||
}},
|
||||
"host2": {{
|
||||
},
|
||||
"host2": {
|
||||
VPN: vpn.OpenVPN,
|
||||
Region: "region2",
|
||||
Country: "country1",
|
||||
@@ -87,7 +81,7 @@ func Test_addServersFromAPI(t *testing.T) {
|
||||
Hostname: "host2",
|
||||
TCP: true,
|
||||
UDP: true,
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -7,94 +7,62 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type hostToServers map[string][]models.Server
|
||||
type hostToServer map[string]models.Server
|
||||
|
||||
func (hts hostToServers) addOpenVPN(host, region, country, city,
|
||||
retroLoc string, tcp, udp bool) {
|
||||
// Check for existing server for this host and OpenVPN.
|
||||
servers := hts[host]
|
||||
for i, existingServer := range servers {
|
||||
if existingServer.Hostname != host ||
|
||||
existingServer.VPN != vpn.OpenVPN {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update OpenVPN supported protocols and return
|
||||
if !existingServer.TCP {
|
||||
servers[i].TCP = tcp
|
||||
}
|
||||
if !existingServer.UDP {
|
||||
servers[i].UDP = udp
|
||||
func (hts hostToServer) add(host, vpnType, region, country, city,
|
||||
retroLoc, wgPubKey string, openvpnTCP, openvpnUDP bool) {
|
||||
server, ok := hts[host]
|
||||
if !ok {
|
||||
server := models.Server{
|
||||
VPN: vpnType,
|
||||
Region: region,
|
||||
Country: country,
|
||||
City: city,
|
||||
RetroLoc: retroLoc,
|
||||
Hostname: host,
|
||||
WgPubKey: wgPubKey,
|
||||
TCP: openvpnTCP,
|
||||
UDP: openvpnUDP,
|
||||
}
|
||||
hts[host] = server
|
||||
return
|
||||
}
|
||||
|
||||
server := models.Server{
|
||||
VPN: vpn.OpenVPN,
|
||||
Region: region,
|
||||
Country: country,
|
||||
City: city,
|
||||
RetroLoc: retroLoc,
|
||||
Hostname: host,
|
||||
TCP: tcp,
|
||||
UDP: udp,
|
||||
server.SetVPN(vpnType)
|
||||
if vpnType == vpn.OpenVPN {
|
||||
server.TCP = server.TCP || openvpnTCP
|
||||
server.UDP = server.UDP || openvpnUDP
|
||||
} else if wgPubKey != "" {
|
||||
server.WgPubKey = wgPubKey
|
||||
}
|
||||
hts[host] = append(servers, server)
|
||||
|
||||
hts[host] = server
|
||||
}
|
||||
|
||||
func (hts hostToServers) addWireguard(host, region, country, city, retroLoc,
|
||||
wgPubKey string) {
|
||||
// Check for existing server for this host and Wireguard.
|
||||
servers := hts[host]
|
||||
for _, existingServer := range servers {
|
||||
if existingServer.Hostname == host &&
|
||||
existingServer.VPN == vpn.Wireguard {
|
||||
// No update necessary for Wireguard
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
server := models.Server{
|
||||
VPN: vpn.Wireguard,
|
||||
Region: region,
|
||||
Country: country,
|
||||
City: city,
|
||||
RetroLoc: retroLoc,
|
||||
Hostname: host,
|
||||
WgPubKey: wgPubKey,
|
||||
}
|
||||
hts[host] = append(servers, server)
|
||||
}
|
||||
|
||||
func (hts hostToServers) toHostsSlice() (hosts []string) {
|
||||
const vpnServerTypes = 2 // OpenVPN + Wireguard
|
||||
hosts = make([]string, 0, vpnServerTypes*len(hts))
|
||||
func (hts hostToServer) toHostsSlice() (hosts []string) {
|
||||
hosts = make([]string, 0, len(hts))
|
||||
for host := range hts {
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
return hosts
|
||||
}
|
||||
|
||||
func (hts hostToServers) adaptWithIPs(hostToIPs map[string][]netip.Addr) {
|
||||
for host, IPs := range hostToIPs {
|
||||
servers := hts[host]
|
||||
for i := range servers {
|
||||
servers[i].IPs = IPs
|
||||
}
|
||||
hts[host] = servers
|
||||
}
|
||||
for host, servers := range hts {
|
||||
if len(servers[0].IPs) == 0 {
|
||||
func (hts hostToServer) adaptWithIPs(hostToIPs map[string][]netip.Addr) {
|
||||
for host, server := range hts {
|
||||
ips := hostToIPs[host]
|
||||
if len(ips) == 0 {
|
||||
delete(hts, host)
|
||||
continue
|
||||
}
|
||||
server.IPs = ips
|
||||
hts[host] = server
|
||||
}
|
||||
}
|
||||
|
||||
func (hts hostToServers) toServersSlice() (servers []models.Server) {
|
||||
const vpnServerTypes = 2 // OpenVPN + Wireguard
|
||||
servers = make([]models.Server, 0, vpnServerTypes*len(hts))
|
||||
for _, serversForHost := range hts {
|
||||
servers = append(servers, serversForHost...)
|
||||
func (hts hostToServer) toServersSlice() (servers []models.Server) {
|
||||
servers = make([]models.Server, 0, len(hts))
|
||||
for _, server := range hts {
|
||||
servers = append(servers, server)
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/provider/surfshark/servers"
|
||||
)
|
||||
|
||||
// getRemainingServers finds extra servers not found in the API or in the ZIP file.
|
||||
func getRemainingServers(hts hostToServers) {
|
||||
func getRemainingServers(hts hostToServer) {
|
||||
locationData := servers.LocationData()
|
||||
hostnameToLocationLeft := hostToLocation(locationData)
|
||||
for _, hostnameDone := range hts.toHostsSlice() {
|
||||
@@ -15,7 +16,8 @@ func getRemainingServers(hts hostToServers) {
|
||||
for hostname, locationData := range hostnameToLocationLeft {
|
||||
// we assume the OpenVPN server supports both TCP and UDP
|
||||
const tcp, udp = true, true
|
||||
hts.addOpenVPN(hostname, locationData.Region, locationData.Country,
|
||||
locationData.City, locationData.RetroLoc, tcp, udp)
|
||||
const wgPubKey = ""
|
||||
hts.add(hostname, vpn.OpenVPN, locationData.Region, locationData.Country,
|
||||
locationData.City, locationData.RetroLoc, wgPubKey, tcp, udp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
||||
servers []models.Server, err error) {
|
||||
hts := make(hostToServers)
|
||||
hts := make(hostToServer)
|
||||
|
||||
err = addServersFromAPI(ctx, u.client, hts)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/qdm12/gluetun/internal/provider/surfshark/servers"
|
||||
"github.com/qdm12/gluetun/internal/updater/openvpn"
|
||||
)
|
||||
|
||||
func addOpenVPNServersFromZip(ctx context.Context,
|
||||
unzipper common.Unzipper, hts hostToServers) (
|
||||
unzipper common.Unzipper, hts hostToServer) (
|
||||
warnings []string, err error) {
|
||||
const url = "https://my.surfshark.com/vpn/api/v1/server/configurations"
|
||||
contents, err := unzipper.FetchAndExtract(ctx, url)
|
||||
@@ -66,8 +67,9 @@ func addOpenVPNServersFromZip(ctx context.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
hts.addOpenVPN(host, data.Region, data.Country, data.City,
|
||||
data.RetroLoc, tcp, udp)
|
||||
const wgPubKey = ""
|
||||
hts.add(host, vpn.OpenVPN, data.Region, data.Country, data.City,
|
||||
data.RetroLoc, wgPubKey, tcp, udp)
|
||||
}
|
||||
|
||||
return warnings, nil
|
||||
|
||||
@@ -2,17 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/log"
|
||||
)
|
||||
|
||||
func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
authSettings auth.Settings,
|
||||
func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
buildInfo models.BuildInformation,
|
||||
vpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter,
|
||||
@@ -21,7 +17,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
publicIPLooper PublicIPLoop,
|
||||
storage Storage,
|
||||
ipv6Supported bool,
|
||||
) (httpHandler http.Handler, err error) {
|
||||
) http.Handler {
|
||||
handler := &handler{}
|
||||
|
||||
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
||||
@@ -33,25 +29,16 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
||||
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
||||
|
||||
authMiddleware, err := auth.New(authSettings, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating auth middleware: %w", err)
|
||||
}
|
||||
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
||||
handler.setLogEnabled = handlerWithLog.setEnabled
|
||||
|
||||
middlewares := []func(http.Handler) http.Handler{
|
||||
authMiddleware,
|
||||
log.New(logger, logging),
|
||||
}
|
||||
httpHandler = handler
|
||||
for _, middleware := range middlewares {
|
||||
httpHandler = middleware(httpHandler)
|
||||
}
|
||||
return httpHandler, nil
|
||||
return handlerWithLog
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
v0 http.Handler
|
||||
v1 http.Handler
|
||||
v0 http.Handler
|
||||
v1 http.Handler
|
||||
setLogEnabled func(enabled bool)
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package log
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -7,21 +7,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func New(logger Logger, enabled bool) (
|
||||
middleware func(http.Handler) http.Handler) {
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return &logMiddleware{
|
||||
childHandler: handler,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
enabled: enabled,
|
||||
}
|
||||
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
|
||||
return &logMiddleware{
|
||||
childHandler: childHandler,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
enabled: enabled,
|
||||
}
|
||||
}
|
||||
|
||||
type logMiddleware struct {
|
||||
childHandler http.Handler
|
||||
logger Logger
|
||||
logger infoer
|
||||
timeNow func() time.Time
|
||||
enabled bool
|
||||
enabledMu sync.RWMutex
|
||||
@@ -42,7 +39,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RemoteAddr + " in " + duration.String())
|
||||
}
|
||||
|
||||
func (m *logMiddleware) SetEnabled(enabled bool) {
|
||||
func (m *logMiddleware) setEnabled(enabled bool) {
|
||||
m.enabledMu.Lock()
|
||||
defer m.enabledMu.Unlock()
|
||||
m.enabled = enabled
|
||||
@@ -1,10 +1,8 @@
|
||||
package server
|
||||
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
infoer
|
||||
warner
|
||||
Warnf(format string, args ...any)
|
||||
errorer
|
||||
}
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type apiKeyMethod struct {
|
||||
apiKeyDigest [32]byte
|
||||
}
|
||||
|
||||
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
|
||||
return &apiKeyMethod{
|
||||
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
|
||||
otherTokenMethod, ok := other.(*apiKeyMethod)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
|
||||
}
|
||||
|
||||
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
|
||||
xAPIKey := request.Header.Get("X-API-Key")
|
||||
if xAPIKey == "" {
|
||||
xAPIKey = request.URL.Query().Get("api_key")
|
||||
}
|
||||
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
|
||||
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type basicAuthMethod struct {
|
||||
authDigest [32]byte
|
||||
}
|
||||
|
||||
func newBasicAuthMethod(username, password string) *basicAuthMethod {
|
||||
return &basicAuthMethod{
|
||||
authDigest: sha256.Sum256([]byte(username + password)),
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
|
||||
otherBasicMethod, ok := other.(*basicAuthMethod)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return a.authDigest == otherBasicMethod.authDigest
|
||||
}
|
||||
|
||||
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
|
||||
username, password, ok := request.BasicAuth()
|
||||
if !ok {
|
||||
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
return false
|
||||
}
|
||||
requestAuthDigest := sha256.Sum256([]byte(username + password))
|
||||
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
// If the file does not exist, it returns empty settings and no error.
|
||||
func Read(filepath string) (settings Settings, err error) {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return Settings{}, nil
|
||||
}
|
||||
return settings, fmt.Errorf("opening file: %w", err)
|
||||
}
|
||||
decoder := toml.NewDecoder(file)
|
||||
decoder.DisallowUnknownFields()
|
||||
err = decoder.Decode(&settings)
|
||||
if err == nil {
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
strictErr := new(toml.StrictMissingError)
|
||||
ok := errors.As(err, &strictErr)
|
||||
if !ok {
|
||||
return settings, fmt.Errorf("toml decoding file: %w", err)
|
||||
}
|
||||
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
|
||||
strictErr, strictErr.String())
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
func Test_Read(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
fileContent string
|
||||
settings Settings
|
||||
errMessage string
|
||||
}{
|
||||
"empty_file": {},
|
||||
"malformed_toml": {
|
||||
fileContent: "this is not a toml file",
|
||||
errMessage: `toml decoding file: toml: expected character =`,
|
||||
},
|
||||
"unknown_field": {
|
||||
fileContent: `unknown = "what is this"`,
|
||||
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
|
||||
1| unknown = "what is this"
|
||||
| ~~~~~~~ missing field`,
|
||||
},
|
||||
"filled_settings": {
|
||||
fileContent: `[[roles]]
|
||||
name = "public"
|
||||
auth = "none"
|
||||
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
|
||||
|
||||
[[roles]]
|
||||
name = "client"
|
||||
auth = "apikey"
|
||||
apikey = "xyz"
|
||||
routes = ["GET /v1/vpn/status"]
|
||||
`,
|
||||
settings: Settings{
|
||||
Roles: []Role{{
|
||||
Name: "public",
|
||||
Auth: AuthNone,
|
||||
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
|
||||
}, {
|
||||
Name: "client",
|
||||
Auth: AuthAPIKey,
|
||||
APIKey: "xyz",
|
||||
Routes: []string{"GET /v1/vpn/status"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
filepath := tempDir + "/config.toml"
|
||||
const permissions fs.FileMode = 0600
|
||||
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings, err := Read(filepath)
|
||||
|
||||
assert.Equal(t, testCase.settings, settings)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package auth
|
||||
|
||||
func andStrings(strings []string) (result string) {
|
||||
return joinStrings(strings, "and")
|
||||
}
|
||||
|
||||
func joinStrings(strings []string, lastJoin string) (result string) {
|
||||
if len(strings) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result = strings[0]
|
||||
for i := 1; i < len(strings); i++ {
|
||||
if i < len(strings)-1 {
|
||||
result += ", " + strings[i]
|
||||
} else {
|
||||
result += " " + lastJoin + " " + strings[i]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
package auth
|
||||
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
type authorizationChecker interface {
|
||||
equal(other authorizationChecker) bool
|
||||
isAuthorized(headers http.Header, request *http.Request) bool
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type internalRole struct {
|
||||
name string
|
||||
checker authorizationChecker
|
||||
}
|
||||
|
||||
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
|
||||
routeToRoles = make(map[string][]internalRole)
|
||||
for _, role := range settings.Roles {
|
||||
var checker authorizationChecker
|
||||
switch role.Auth {
|
||||
case AuthNone:
|
||||
checker = newNoneMethod()
|
||||
case AuthAPIKey:
|
||||
checker = newAPIKeyMethod(role.APIKey)
|
||||
case AuthBasic:
|
||||
checker = newBasicAuthMethod(role.Username, role.Password)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
|
||||
}
|
||||
|
||||
iRole := internalRole{
|
||||
name: role.Name,
|
||||
checker: checker,
|
||||
}
|
||||
for _, route := range role.Routes {
|
||||
checkerExists := false
|
||||
for _, role := range routeToRoles[route] {
|
||||
if role.checker.equal(iRole.checker) {
|
||||
checkerExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if checkerExists {
|
||||
// even if the role name is different, if the checker is the same, skip it.
|
||||
continue
|
||||
}
|
||||
routeToRoles[route] = append(routeToRoles[route], iRole)
|
||||
}
|
||||
}
|
||||
return routeToRoles, nil
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
func Test_settingsToLookupMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
routeToRoles map[string][]internalRole
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty_settings": {
|
||||
routeToRoles: map[string][]internalRole{},
|
||||
},
|
||||
"auth_method_not_supported": {
|
||||
settings: Settings{
|
||||
Roles: []Role{{Name: "a", Auth: "bad"}},
|
||||
},
|
||||
errWrapped: ErrMethodNotSupported,
|
||||
errMessage: "authentication method not supported: bad",
|
||||
},
|
||||
"success": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
|
||||
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
|
||||
},
|
||||
},
|
||||
routeToRoles: map[string][]internalRole{
|
||||
"GET /path": {
|
||||
{name: "a", checker: newNoneMethod()}, // deduplicated method
|
||||
},
|
||||
"PUT /path": {
|
||||
{name: "b", checker: newNoneMethod()},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
routeToRoles, err := settingsToLookupMap(testCase.settings)
|
||||
|
||||
assert.Equal(t, testCase.routeToRoles, routeToRoles)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func New(settings Settings, debugLogger DebugLogger) (
|
||||
middleware func(http.Handler) http.Handler,
|
||||
err error) {
|
||||
routeToRoles, err := settingsToLookupMap(settings)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
|
||||
}
|
||||
|
||||
//nolint:goconst
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return &authHandler{
|
||||
childHandler: handler,
|
||||
routeToRoles: routeToRoles,
|
||||
unprotectedRoutes: map[string]struct{}{
|
||||
http.MethodGet + " /openvpn/actions/restart": {},
|
||||
http.MethodGet + " /unbound/actions/restart": {},
|
||||
http.MethodGet + " /updater/restart": {},
|
||||
http.MethodGet + " /v1/version": {},
|
||||
http.MethodGet + " /v1/vpn/status": {},
|
||||
http.MethodPut + " /v1/vpn/status": {},
|
||||
// GET /v1/vpn/settings is protected by default
|
||||
// PUT /v1/vpn/settings is protected by default
|
||||
http.MethodGet + " /v1/openvpn/status": {},
|
||||
http.MethodPut + " /v1/openvpn/status": {},
|
||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||
// GET /v1/openvpn/settings is protected by default
|
||||
http.MethodGet + " /v1/dns/status": {},
|
||||
http.MethodPut + " /v1/dns/status": {},
|
||||
http.MethodGet + " /v1/updater/status": {},
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
},
|
||||
logger: debugLogger,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
childHandler http.Handler
|
||||
routeToRoles map[string][]internalRole
|
||||
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
|
||||
logger DebugLogger
|
||||
}
|
||||
|
||||
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
route := request.Method + " " + request.URL.Path
|
||||
roles := h.routeToRoles[route]
|
||||
if len(roles) == 0 {
|
||||
h.logger.Debugf("no authentication role defined for route %s", route)
|
||||
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
responseHeader := make(http.Header, 0)
|
||||
for _, role := range roles {
|
||||
if !role.checker.isAuthorized(responseHeader, request) {
|
||||
continue
|
||||
}
|
||||
|
||||
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
|
||||
|
||||
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
|
||||
h.childHandler.ServeHTTP(writer, request)
|
||||
return
|
||||
}
|
||||
|
||||
// Flush out response headers if all roles failed to authenticate
|
||||
for headerKey, headerValues := range responseHeader {
|
||||
for _, headerValue := range headerValues {
|
||||
writer.Header().Add(headerKey, headerValue)
|
||||
}
|
||||
}
|
||||
|
||||
allRoleNames := make([]string, len(roles))
|
||||
for i, role := range roles {
|
||||
allRoleNames[i] = role.name
|
||||
}
|
||||
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
|
||||
route, andStrings(allRoleNames))
|
||||
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
|
||||
// TODO v3.41.0 remove
|
||||
if role.name != "public" {
|
||||
// custom role name, allow none authentication to be specified
|
||||
return
|
||||
}
|
||||
_, isNoneChecker := role.checker.(*noneMethod)
|
||||
if !isNoneChecker {
|
||||
// not the none authentication method
|
||||
return
|
||||
}
|
||||
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
|
||||
if !isUnprotectedByDefault {
|
||||
// route is not unprotected by default, so this is a user decision
|
||||
return
|
||||
}
|
||||
h.logger.Warnf("route %s is unprotected by default, "+
|
||||
"please set up authentication following the documentation at "+
|
||||
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||
"since this will become no longer publicly accessible after release v3.40.",
|
||||
route)
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_authHandler_ServeHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
|
||||
requestMethod string
|
||||
requestPath string
|
||||
statusCode int
|
||||
responseBody string
|
||||
}{
|
||||
"route_has_no_role": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/b",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
responseBody: "Unauthorized\n",
|
||||
},
|
||||
"authorized_unprotected_by_default": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Warnf("route %s is unprotected by default, "+
|
||||
"please set up authentication following the documentation at "+
|
||||
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||
"since this will become no longer publicly accessible after release v3.40.",
|
||||
"GET /v1/vpn/status")
|
||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||
"GET /v1/vpn/status", "public")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/v1/vpn/status",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
"authorized_none": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||
"GET /a", "role1")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/a",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var debugLogger DebugLogger
|
||||
if testCase.makeLogger != nil {
|
||||
debugLogger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
middleware, err := New(testCase.settings, debugLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := middleware(childHandler)
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
client := server.Client()
|
||||
|
||||
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
|
||||
require.NoError(t, err)
|
||||
request, err := http.NewRequestWithContext(context.Background(),
|
||||
testCase.requestMethod, requestURL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := client.Do(request)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err = response.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
assert.Equal(t, testCase.statusCode, response.StatusCode)
|
||||
body, err := io.ReadAll(response.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testCase.responseBody, string(body))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
package auth
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger
|
||||
@@ -1,68 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
|
||||
|
||||
// Package auth is a generated GoMock package.
|
||||
package auth
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDebugLogger is a mock of DebugLogger interface.
|
||||
type MockDebugLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDebugLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
|
||||
type MockDebugLoggerMockRecorder struct {
|
||||
mock *MockDebugLogger
|
||||
}
|
||||
|
||||
// NewMockDebugLogger creates a new mock instance.
|
||||
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
mock := &MockDebugLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockDebugLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debugf mocks base method.
|
||||
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Debugf", varargs...)
|
||||
}
|
||||
|
||||
// Debugf indicates an expected call of Debugf.
|
||||
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
|
||||
}
|
||||
|
||||
// Warnf mocks base method.
|
||||
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Warnf", varargs...)
|
||||
}
|
||||
|
||||
// Warnf indicates an expected call of Warnf.
|
||||
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
type noneMethod struct{}
|
||||
|
||||
func newNoneMethod() *noneMethod {
|
||||
return &noneMethod{}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (n *noneMethod) equal(other authorizationChecker) bool {
|
||||
_, ok := other.(*noneMethod)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/validate"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
// Roles is a list of roles with their associated authentication
|
||||
// and routes.
|
||||
Roles []Role
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
|
||||
Name: "public",
|
||||
Auth: "none",
|
||||
Routes: []string{
|
||||
http.MethodGet + " /openvpn/actions/restart",
|
||||
http.MethodGet + " /unbound/actions/restart",
|
||||
http.MethodGet + " /updater/restart",
|
||||
http.MethodGet + " /v1/version",
|
||||
http.MethodGet + " /v1/vpn/status",
|
||||
http.MethodPut + " /v1/vpn/status",
|
||||
http.MethodGet + " /v1/openvpn/status",
|
||||
http.MethodPut + " /v1/openvpn/status",
|
||||
http.MethodGet + " /v1/openvpn/portforwarded",
|
||||
http.MethodGet + " /v1/dns/status",
|
||||
http.MethodPut + " /v1/dns/status",
|
||||
http.MethodGet + " /v1/updater/status",
|
||||
http.MethodPut + " /v1/updater/status",
|
||||
http.MethodGet + " /v1/publicip/ip",
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
func (s Settings) Validate() (err error) {
|
||||
for i, role := range s.Roles {
|
||||
err = role.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("role %s (%d of %d): %w",
|
||||
role.Name, i+1, len(s.Roles), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
AuthNone = "none"
|
||||
AuthAPIKey = "apikey"
|
||||
AuthBasic = "basic"
|
||||
)
|
||||
|
||||
// Role contains the role name, authentication method name and
|
||||
// routes that the role can access.
|
||||
type Role struct {
|
||||
// Name is the role name and is only used for documentation
|
||||
// and in the authentication middleware debug logs.
|
||||
Name string
|
||||
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
|
||||
Auth string
|
||||
// APIKey is the API key to use when using the 'apikey' authentication.
|
||||
APIKey string
|
||||
// Username for HTTP Basic authentication method.
|
||||
Username string
|
||||
// Password for HTTP Basic authentication method.
|
||||
Password string
|
||||
// Routes is a list of routes that the role can access in the format
|
||||
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
|
||||
Routes []string
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMethodNotSupported = errors.New("authentication method not supported")
|
||||
ErrAPIKeyEmpty = errors.New("api key is empty")
|
||||
ErrBasicUsernameEmpty = errors.New("username is empty")
|
||||
ErrBasicPasswordEmpty = errors.New("password is empty")
|
||||
ErrRouteNotSupported = errors.New("route not supported by the control server")
|
||||
)
|
||||
|
||||
func (r Role) validate() (err error) {
|
||||
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Auth == AuthAPIKey && r.APIKey == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
|
||||
case r.Auth == AuthBasic && r.Username == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
|
||||
case r.Auth == AuthBasic && r.Password == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
|
||||
}
|
||||
|
||||
for i, route := range r.Routes {
|
||||
_, ok := validRoutes[route]
|
||||
if !ok {
|
||||
return fmt.Errorf("route %d of %d: %w: %s",
|
||||
i+1, len(r.Routes), ErrRouteNotSupported, route)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WARNING: do not mutate programmatically.
|
||||
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
||||
http.MethodGet + " /openvpn/actions/restart": {},
|
||||
http.MethodGet + " /unbound/actions/restart": {},
|
||||
http.MethodGet + " /updater/restart": {},
|
||||
http.MethodGet + " /v1/version": {},
|
||||
http.MethodGet + " /v1/vpn/status": {},
|
||||
http.MethodPut + " /v1/vpn/status": {},
|
||||
http.MethodGet + " /v1/vpn/settings": {},
|
||||
http.MethodPut + " /v1/vpn/settings": {},
|
||||
http.MethodGet + " /v1/openvpn/status": {},
|
||||
http.MethodPut + " /v1/openvpn/status": {},
|
||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||
http.MethodGet + " /v1/openvpn/settings": {},
|
||||
http.MethodGet + " /v1/dns/status": {},
|
||||
http.MethodPut + " /v1/dns/status": {},
|
||||
http.MethodGet + " /v1/updater/status": {},
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package log
|
||||
|
||||
type Logger interface {
|
||||
Info(message string)
|
||||
}
|
||||
@@ -6,31 +6,17 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/httpserver"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||
)
|
||||
|
||||
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
||||
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
||||
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
||||
ipv6Supported bool) (
|
||||
server *httpserver.Server, err error) {
|
||||
authSettings, err := auth.Read(authConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading auth settings: %w", err)
|
||||
}
|
||||
authSettings.SetDefaults()
|
||||
err = authSettings.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating auth settings: %w", err)
|
||||
}
|
||||
|
||||
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
|
||||
handler := newHandler(ctx, logger, logEnabled, buildInfo,
|
||||
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating handler: %w", err)
|
||||
}
|
||||
|
||||
httpServerSettings := httpserver.Settings{
|
||||
Address: address,
|
||||
|
||||
@@ -50,11 +50,11 @@ func filterServer(server models.Server,
|
||||
selection settings.ServerSelection) (filtered bool) {
|
||||
// Note each condition is split to make sure
|
||||
// we have full testing coverage.
|
||||
if server.VPN != selection.VPN {
|
||||
if server.VPN != vpn.Both && server.VPN != selection.VPN {
|
||||
return true
|
||||
}
|
||||
|
||||
if server.VPN != vpn.Wireguard &&
|
||||
if selection.VPN == vpn.OpenVPN &&
|
||||
filterByProtocol(selection, server.TCP, server.UDP) {
|
||||
return true
|
||||
}
|
||||
@@ -119,8 +119,6 @@ func filterServer(server models.Server,
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO filter port forward server for PIA
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -128,31 +128,6 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
|
||||
messageParts = append(messageParts, "premium tier only")
|
||||
}
|
||||
|
||||
if *selection.StreamOnly {
|
||||
messageParts = append(messageParts, "stream only")
|
||||
}
|
||||
|
||||
if *selection.MultiHopOnly {
|
||||
messageParts = append(messageParts, "multihop only")
|
||||
}
|
||||
|
||||
if *selection.PortForwardOnly {
|
||||
messageParts = append(messageParts, "port forwarding only")
|
||||
}
|
||||
|
||||
if *selection.SecureCoreOnly {
|
||||
messageParts = append(messageParts, "secure core only")
|
||||
}
|
||||
|
||||
if *selection.TorOnly {
|
||||
messageParts = append(messageParts, "tor only")
|
||||
}
|
||||
|
||||
if selection.TargetIP.IsValid() {
|
||||
messageParts = append(messageParts,
|
||||
"target ip address "+selection.TargetIP.String())
|
||||
}
|
||||
|
||||
message := "for " + strings.Join(messageParts, "; ")
|
||||
|
||||
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -50,17 +49,13 @@ func getLatestRelease(ctx context.Context, client *http.Client) (tagName, name s
|
||||
if err != nil {
|
||||
return "", "", time, err
|
||||
}
|
||||
// Sort releases by tag names (semver)
|
||||
sort.Slice(releases, func(i, j int) bool {
|
||||
return releases[i].TagName > releases[j].TagName
|
||||
})
|
||||
for _, release := range releases {
|
||||
if release.Prerelease {
|
||||
continue
|
||||
}
|
||||
return release.TagName, release.Name, release.PublishedAt, nil
|
||||
}
|
||||
return "", "", time, fmt.Errorf("%w", errReleaseNotFound)
|
||||
return "", "", time, errReleaseNotFound
|
||||
}
|
||||
|
||||
var errCommitNotFound = errors.New("commit not found")
|
||||
|
||||
Reference in New Issue
Block a user