Compare commits

..

18 Commits

Author SHA1 Message Date
Quentin McGaw
90b9e81129 Merge branch 'master' into pmtu 2025-11-07 21:55:58 +00:00
Quentin McGaw
2391c890b4 Run MTU discovery AFTER healthcheck is started 2025-10-17 00:39:44 +00:00
Quentin McGaw
51fd46b58e Merge branch 'master' into pmtu 2025-10-17 00:17:45 +00:00
Quentin McGaw
906e7b5ee1 Remove unneeded error context wrapping 2025-10-14 17:56:54 +00:00
Quentin McGaw
5428580b8f Handle ICMP not permitted errors 2025-10-14 17:56:04 +00:00
Quentin McGaw
6c25ee53f1 Fix unit test 2025-10-06 11:08:03 +00:00
Quentin McGaw
b9051b02bf Use the VPN local gateway IP address to run path MTU discovery 2025-10-06 10:03:15 +00:00
Quentin McGaw
f0f3193c1c Remove VPN_PMTUD option 2025-10-06 09:57:15 +00:00
Quentin McGaw
c0ebd180cb Revert to VPN original MTU (set by WIREGUARD_MTU for example) if ICMP fails 2025-10-06 09:57:15 +00:00
Quentin McGaw
b6e873cf25 Improve logging in case of ICMP blocked 2025-10-06 09:57:15 +00:00
Quentin McGaw
ccc2f306b9 Fallback on 1320 if ICMP is blocked 2025-10-06 09:57:15 +00:00
Quentin McGaw
5b1dc295fe Return an error if all MTUs failed to test 2025-10-06 09:57:15 +00:00
Quentin McGaw
00bc8bbbbb Handle administrative prohibition of ICMP 2025-10-06 09:57:15 +00:00
Quentin McGaw
8bef380d8c Fix unit test 2025-10-06 09:57:15 +00:00
Quentin McGaw
9ad1907574 Update log that PMTUD can take up to 4s 2025-10-06 09:57:15 +00:00
Quentin McGaw
d83999d954 Make binary search faster with 11 parallel queries 2025-10-06 09:57:15 +00:00
Quentin McGaw
162d244865 Use PMTUD to set the MTU to the VPN interface
- Add `VPN_PMTUD` option enabled by default
- One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
2025-10-06 09:57:15 +00:00
Quentin McGaw
e21d798f57 pmtud package 2025-10-06 09:57:15 +00:00
75 changed files with 1378 additions and 1494 deletions

View File

@@ -20,7 +20,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v5
- uses: DavidAnson/markdownlint-cli2-action@v21 - uses: DavidAnson/markdownlint-cli2-action@v20
with: with:
globs: "**.md" globs: "**.md"
config: .markdownlint-cli2.jsonc config: .markdownlint-cli2.jsonc

View File

@@ -56,9 +56,6 @@ linters:
- revive - revive
path: internal\/provider\/(common|utils)\/.+\.go path: internal\/provider\/(common|utils)\/.+\.go
text: "var-naming: avoid (bad|meaningless) package names" text: "var-naming: avoid (bad|meaningless) package names"
- linters:
- lll
source: "^// https://.+$"
- linters: - linters:
- err113 - err113
- mnd - mnd

View File

@@ -167,19 +167,20 @@ ENV VPN_SERVICE_PROVIDER=pia \
HEALTH_ICMP_TARGET_IP=1.1.1.1 \ HEALTH_ICMP_TARGET_IP=1.1.1.1 \
HEALTH_RESTART_VPN=on \ HEALTH_RESTART_VPN=on \
# DNS # DNS
DNS_SERVER=on \
DNS_UPSTREAM_RESOLVER_TYPE=DoT \ DNS_UPSTREAM_RESOLVER_TYPE=DoT \
DNS_UPSTREAM_RESOLVERS=cloudflare \ DNS_UPSTREAM_RESOLVERS=cloudflare \
DNS_BLOCK_IPS= \ DNS_BLOCK_IPS= \
DNS_BLOCK_IP_PREFIXES= \ DNS_BLOCK_IP_PREFIXES=127.0.0.1/8,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16,::1/128,fc00::/7,fe80::/10,::ffff:7f00:1/104,::ffff:a00:0/104,::ffff:a9fe:0/112,::ffff:ac10:0/108,::ffff:c0a8:0/112 \
DNS_CACHING=on \ DNS_CACHING=on \
DNS_UPSTREAM_IPV6=off \ DNS_UPSTREAM_IPV6=off \
BLOCK_MALICIOUS=on \ BLOCK_MALICIOUS=on \
BLOCK_SURVEILLANCE=off \ BLOCK_SURVEILLANCE=off \
BLOCK_ADS=off \ BLOCK_ADS=off \
DNS_UNBLOCK_HOSTNAMES= \ DNS_UNBLOCK_HOSTNAMES= \
DNS_REBINDING_PROTECTION_EXEMPT_HOSTNAMES= \
DNS_UPDATE_PERIOD=24h \ DNS_UPDATE_PERIOD=24h \
DNS_UPSTREAM_PLAIN_ADDRESSES= \ DNS_ADDRESS=127.0.0.1 \
DNS_KEEP_NAMESERVER=off \
# HTTP proxy # HTTP proxy
HTTPPROXY= \ HTTPPROXY= \
HTTPPROXY_LOG=off \ HTTPPROXY_LOG=off \
@@ -200,13 +201,10 @@ ENV VPN_SERVICE_PROVIDER=pia \
HTTP_CONTROL_SERVER_LOG=on \ HTTP_CONTROL_SERVER_LOG=on \
HTTP_CONTROL_SERVER_ADDRESS=":8000" \ HTTP_CONTROL_SERVER_ADDRESS=":8000" \
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \ HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
HTTP_CONTROL_SERVER_AUTH_DEFAULT_ROLE="{}" \
# Server data updater # Server data updater
UPDATER_PERIOD=0 \ UPDATER_PERIOD=0 \
UPDATER_MIN_RATIO=0.8 \ UPDATER_MIN_RATIO=0.8 \
UPDATER_VPN_SERVICE_PROVIDERS= \ UPDATER_VPN_SERVICE_PROVIDERS= \
UPDATER_PROTONVPN_USERNAME= \
UPDATER_PROTONVPN_PASSWORD= \
# Public IP # Public IP
PUBLICIP_FILE="/tmp/gluetun/ip" \ PUBLICIP_FILE="/tmp/gluetun/ip" \
PUBLICIP_ENABLED=on \ PUBLICIP_ENABLED=on \
@@ -222,8 +220,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
# Extras # Extras
VERSION_INFORMATION=on \ VERSION_INFORMATION=on \
TZ= \ TZ= \
PUID=1000 \ PUID= \
PGID=1000 PGID=
ENTRYPOINT ["/gluetun-entrypoint"] ENTRYPOINT ["/gluetun-entrypoint"]
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp
HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck HEALTHCHECK --interval=5s --timeout=5s --start-period=10s --retries=3 CMD /gluetun-entrypoint healthcheck

View File

@@ -1,7 +1,5 @@
# Gluetun VPN client # Gluetun VPN client
⚠️ This and [gluetun-wiki](https://github.com/qdm12/gluetun-wiki) are the only websites for Gluetun, other websites claiming to be official are scams ⚠️
Lightweight swiss-army-knife-like VPN client to multiple VPN service providers Lightweight swiss-army-knife-like VPN client to multiple VPN service providers
![Title image](https://raw.githubusercontent.com/qdm12/gluetun/master/title.svg) ![Title image](https://raw.githubusercontent.com/qdm12/gluetun/master/title.svg)

View File

@@ -427,8 +427,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
parallelResolver := resolver.NewParallelResolver(allSettings.Updater.DNSAddress) parallelResolver := resolver.NewParallelResolver(allSettings.Updater.DNSAddress)
openvpnFileExtractor := extract.New() openvpnFileExtractor := extract.New()
providers := provider.NewProviders(storage, time.Now, updaterLogger, providers := provider.NewProviders(storage, time.Now, updaterLogger,
httpClient, unzipper, parallelResolver, publicIPLooper.Fetcher(), httpClient, unzipper, parallelResolver, publicIPLooper.Fetcher(), openvpnFileExtractor)
openvpnFileExtractor, allSettings.Updater)
vpnLogger := logger.New(log.SetComponent("vpn")) vpnLogger := logger.New(log.SetComponent("vpn"))
vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6Supported, allSettings.Firewall.VPNInputPorts, vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6Supported, allSettings.Firewall.VPNInputPorts,
@@ -467,10 +466,13 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
go shadowsocksLooper.Run(shadowsocksCtx, shadowsocksDone) go shadowsocksLooper.Run(shadowsocksCtx, shadowsocksDone)
otherGroupHandler.Add(shadowsocksHandler) otherGroupHandler.Add(shadowsocksHandler)
controlServerAddress := *allSettings.ControlServer.Address
controlServerLogging := *allSettings.ControlServer.Log
httpServerHandler, httpServerCtx, httpServerDone := goshutdown.NewGoRoutineHandler( httpServerHandler, httpServerCtx, httpServerDone := goshutdown.NewGoRoutineHandler(
"http server", goroutine.OptionTimeout(defaultShutdownTimeout)) "http server", goroutine.OptionTimeout(defaultShutdownTimeout))
httpServer, err := server.New(httpServerCtx, allSettings.ControlServer, httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.New(log.SetComponent("http server")), logger.New(log.SetComponent("http server")),
allSettings.ControlServer.AuthFilePath,
buildInfo, vpnLooper, portForwardLooper, dnsLooper, updaterLooper, publicIPLooper, buildInfo, vpnLooper, portForwardLooper, dnsLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported) storage, ipv6Supported)
if err != nil { if err != nil {
@@ -579,6 +581,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) error
} }
type clier interface { type clier interface {

23
go.mod
View File

@@ -3,14 +3,13 @@ module github.com/qdm12/gluetun
go 1.25.0 go 1.25.0
require ( require (
github.com/ProtonMail/go-srp v0.0.7
github.com/breml/rootcerts v0.3.3 github.com/breml/rootcerts v0.3.3
github.com/fatih/color v1.18.0 github.com/fatih/color v1.18.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/klauspost/compress v1.18.1 github.com/klauspost/compress v1.18.1
github.com/klauspost/pgzip v1.2.6 github.com/klauspost/pgzip v1.2.6
github.com/pelletier/go-toml/v2 v2.2.4 github.com/pelletier/go-toml/v2 v2.2.4
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f github.com/qdm12/dns/v2 v2.0.0-rc9
github.com/qdm12/gosettings v0.4.4 github.com/qdm12/gosettings v0.4.4
github.com/qdm12/goshutdown v0.3.0 github.com/qdm12/goshutdown v0.3.0
github.com/qdm12/gosplash v0.2.0 github.com/qdm12/gosplash v0.2.0
@@ -22,21 +21,17 @@ require (
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
golang.org/x/net v0.47.0 golang.org/x/net v0.46.0
golang.org/x/sys v0.38.0 golang.org/x/sys v0.37.0
golang.org/x/text v0.31.0 golang.org/x/text v0.30.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gopkg.in/ini.v1 v1.67.0 gopkg.in/ini.v1 v1.67.0
) )
require ( require (
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect
github.com/ProtonMail/go-crypto v1.3.0-proton // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
@@ -47,7 +42,6 @@ require (
github.com/mdlayher/socket v0.4.1 // indirect github.com/mdlayher/socket v0.4.1 // indirect
github.com/miekg/dns v1.1.62 // indirect github.com/miekg/dns v1.1.62 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.20.5 // indirect github.com/prometheus/client_golang v1.20.5 // indirect
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
@@ -56,11 +50,10 @@ require (
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.44.0 // indirect golang.org/x/crypto v0.43.0 // indirect
golang.org/x/mod v0.29.0 // indirect golang.org/x/mod v0.28.0 // indirect
golang.org/x/sync v0.18.0 // indirect golang.org/x/sync v0.17.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.37.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/protobuf v1.35.1 // indirect google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

76
go.sum
View File

@@ -1,23 +1,9 @@
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
github.com/ProtonMail/go-crypto v0.0.0-20230321155629-9a39f2531310/go.mod h1:8TI4H3IbrackdNgv+92dI+rhpCaLqM0IfpgCgenFvRE=
github.com/ProtonMail/go-crypto v1.3.0-proton h1:tAQKQRZX/73VmzK6yHSCaRUOvS/3OYSQzhXQsrR7yUM=
github.com/ProtonMail/go-crypto v1.3.0-proton/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/ProtonMail/go-srp v0.0.7 h1:Sos3Qk+th4tQR64vsxGIxYpN3rdnG9Wf9K4ZloC1JrI=
github.com/ProtonMail/go-srp v0.0.7/go.mod h1:giCp+7qRnMIcCvI6V6U3S1lDDXDQYx2ewJ6F/9wdlJk=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/breml/rootcerts v0.3.3 h1://GnaRtQ/9BY2+GtMk2wtWxVdCRysiaPr5/xBwl7NKw= github.com/breml/rootcerts v0.3.3 h1://GnaRtQ/9BY2+GtMk2wtWxVdCRysiaPr5/xBwl7NKw=
github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw= github.com/breml/rootcerts v0.3.3/go.mod h1:S/PKh+4d1HUn4HQovEB8hPJZO6pUZYrIhmXBhsegfXw=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cronokirby/saferith v0.33.0 h1:TgoQlfsD4LIwx71+ChfRcIpjkw+RPOapDEVxa+LhwLo=
github.com/cronokirby/saferith v0.33.0/go.mod h1:QKJhjoqUtBsXCAVEjw38mFqoi7DebT7kthcD7UzbnoA=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
@@ -57,8 +43,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
@@ -69,8 +53,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw= github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f h1:6wN5D9wACfmXDsQ366egVt0jXY4nqL/QnIwg4nWhXco= github.com/qdm12/dns/v2 v2.0.0-rc9 h1:qDzRkHr6993jknNB/ZOCnZOyIG6bsZcl2MIfdeUd0kI=
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE= github.com/qdm12/dns/v2 v2.0.0-rc9/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c= github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg= github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4= github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
@@ -100,72 +84,48 @@ github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZla
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -7,4 +7,3 @@ func newNoopLogger() *noopLogger {
} }
func (l *noopLogger) Info(string) {} func (l *noopLogger) Info(string) {}
func (l *noopLogger) Warn(string) {}

View File

@@ -76,7 +76,7 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
openvpnFileExtractor := extract.New() openvpnFileExtractor := extract.New()
providers := provider.NewProviders(storage, time.Now, warner, client, providers := provider.NewProviders(storage, time.Now, warner, client,
unzipper, parallelResolver, ipFetcher, openvpnFileExtractor, allSettings.Updater) unzipper, parallelResolver, ipFetcher, openvpnFileExtractor)
providerConf := providers.Get(allSettings.VPN.Provider.Name) providerConf := providers.Get(allSettings.VPN.Provider.Name)
connection, err := providerConf.GetConnection( connection, err := providerConf.GetConnection(
allSettings.VPN.Provider.ServerSelection, ipv6Supported) allSettings.VPN.Provider.ServerSelection, ipv6Supported)

View File

@@ -6,7 +6,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"net/http" "net/http"
"slices"
"strings" "strings"
"time" "time"
@@ -25,8 +24,6 @@ import (
var ( var (
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified") ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
ErrNoProviderSpecified = errors.New("no provider was specified") ErrNoProviderSpecified = errors.New("no provider was specified")
ErrUsernameMissing = errors.New("username is required for this provider")
ErrPasswordMissing = errors.New("password is required for this provider")
) )
type UpdaterLogger interface { type UpdaterLogger interface {
@@ -38,7 +35,7 @@ type UpdaterLogger interface {
func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error { func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error {
options := settings.Updater{} options := settings.Updater{}
var endUserMode, maintainerMode, updateAll bool var endUserMode, maintainerMode, updateAll bool
var csvProviders, ipToken, protonUsername, protonPassword string var csvProviders, ipToken string
flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet := flag.NewFlagSet("update", flag.ExitOnError)
flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)") flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)")
flagSet.BoolVar(&maintainerMode, "maintainer", false, flagSet.BoolVar(&maintainerMode, "maintainer", false,
@@ -50,8 +47,6 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers") flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers")
flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for") flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for")
flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use") flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use")
flagSet.StringVar(&protonUsername, "proton-username", "", "Username to use to authenticate with Proton")
flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton")
if err := flagSet.Parse(args); err != nil { if err := flagSet.Parse(args); err != nil {
return err return err
} }
@@ -69,11 +64,6 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
options.Providers = strings.Split(csvProviders, ",") options.Providers = strings.Split(csvProviders, ",")
} }
if slices.Contains(options.Providers, providers.Protonvpn) {
options.ProtonUsername = &protonUsername
options.ProtonPassword = &protonPassword
}
options.SetDefaults(options.Providers[0]) options.SetDefaults(options.Providers[0])
err := options.Validate() err := options.Validate()
@@ -81,11 +71,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
return fmt.Errorf("options validation failed: %w", err) return fmt.Errorf("options validation failed: %w", err)
} }
serversDataPath := constants.ServersData storage, err := storage.New(logger, constants.ServersData)
if maintainerMode {
serversDataPath = ""
}
storage, err := storage.New(logger, serversDataPath)
if err != nil { if err != nil {
return fmt.Errorf("creating servers storage: %w", err) return fmt.Errorf("creating servers storage: %w", err)
} }
@@ -108,7 +94,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
openvpnFileExtractor := extract.New() openvpnFileExtractor := extract.New()
providers := provider.NewProviders(storage, time.Now, logger, httpClient, providers := provider.NewProviders(storage, time.Now, logger, httpClient,
unzipper, parallelResolver, ipFetcher, openvpnFileExtractor, options) unzipper, parallelResolver, ipFetcher, openvpnFileExtractor)
updater := updater.New(httpClient, storage, providers, logger) updater := updater.New(httpClient, storage, providers, logger)
err = updater.UpdateServers(ctx, options.Providers, options.MinRatio) err = updater.UpdateServers(ctx, options.Providers, options.MinRatio)

View File

@@ -14,10 +14,6 @@ func readObsolete(r *reader.Reader) (warnings []string) {
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.", "DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete", "HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete", "HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
"DNS_SERVER": "DNS_SERVER is obsolete because the forwarding server is always enabled.",
"DOT": "DOT is obsolete because the forwarding server is always enabled.",
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because the forwarding server is always used and " +
"forwards local names to private DNS resolvers found in /etc/resolv.conf",
} }
sortedKeys := maps.Keys(keyToMessage) sortedKeys := maps.Keys(keyToMessage)
slices.Sort(sortedKeys) slices.Sort(sortedKeys)

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/qdm12/dns/v2/pkg/provider" "github.com/qdm12/dns/v2/pkg/provider"
@@ -14,25 +13,20 @@ import (
"github.com/qdm12/gotree" "github.com/qdm12/gotree"
) )
const (
DNSUpstreamTypeDot = "dot"
DNSUpstreamTypeDoh = "doh"
DNSUpstreamTypePlain = "plain"
)
// DNS contains settings to configure DNS. // DNS contains settings to configure DNS.
type DNS struct { type DNS struct {
// UpstreamType can be [dnsUpstreamTypeDot], [dnsUpstreamTypeDoh] // ServerEnabled is true if the server should be running
// or [dnsUpstreamTypePlain]. It defaults to [dnsUpstreamTypeDot]. // and used. It defaults to true, and cannot be nil
// in the internal state.
ServerEnabled *bool
// UpstreamType can be dot or plain, and defaults to dot.
UpstreamType string `json:"upstream_type"` UpstreamType string `json:"upstream_type"`
// UpdatePeriod is the period to update DNS block lists. // UpdatePeriod is the period to update DNS block lists.
// It can be set to 0 to disable the update. // It can be set to 0 to disable the update.
// It defaults to 24h and cannot be nil in // It defaults to 24h and cannot be nil in
// the internal state. // the internal state.
UpdatePeriod *time.Duration UpdatePeriod *time.Duration
// Providers is a list of DNS providers. // Providers is a list of DNS providers
// It defaults to either ["cloudflare"] or [] if the
// UpstreamPlainAddresses field is set.
Providers []string `json:"providers"` Providers []string `json:"providers"`
// Caching is true if the server should cache // Caching is true if the server should cache
// DNS responses. // DNS responses.
@@ -42,23 +36,32 @@ type DNS struct {
// Blacklist contains settings to configure the filter // Blacklist contains settings to configure the filter
// block lists. // block lists.
Blacklist DNSBlacklist Blacklist DNSBlacklist
// UpstreamPlainAddresses are the upstream plaintext DNS resolver // ServerAddress is the DNS server to use inside
// addresses to use by the built-in DNS server forwarder. // the Go program and for the system.
// Note, if the upstream type is [dnsUpstreamTypePlain] these are merged // It defaults to '127.0.0.1' to be used with the
// together with provider names set in the Providers field. // local server. It cannot be the zero value in the internal
// If this field is set, the Providers field will default to the empty slice. // state.
UpstreamPlainAddresses []netip.AddrPort ServerAddress netip.Addr
// KeepNameserver is true if the existing DNS server
// found in /etc/resolv.conf should be used
// Note setting this to true will likely DNS traffic
// outside the VPN tunnel since it would go through
// the local DNS server of your Docker/Kubernetes
// configuration, which is likely not going through the tunnel.
// This will also disable the DNS forwarder server and the
// `ServerAddress` field will be ignored.
// It defaults to false and cannot be nil in the
// internal state.
KeepNameserver *bool
} }
var ( var (
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid") ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short") ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
) )
func (d DNS) validate() (err error) { func (d DNS) validate() (err error) {
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) { if !helpers.IsOneOf(d.UpstreamType, "dot", "doh", "plain") {
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType) return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
} }
@@ -76,18 +79,6 @@ func (d DNS) validate() (err error) {
} }
} }
if d.UpstreamType == DNSUpstreamTypePlain {
if *d.IPv6 && !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool {
return addrPort.Addr().Is6()
}) {
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
} else if !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool {
return addrPort.Addr().Is4()
}) {
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
}
}
err = d.Blacklist.validate() err = d.Blacklist.validate()
if err != nil { if err != nil {
return err return err
@@ -98,13 +89,15 @@ func (d DNS) validate() (err error) {
func (d *DNS) Copy() (copied DNS) { func (d *DNS) Copy() (copied DNS) {
return DNS{ return DNS{
UpstreamType: d.UpstreamType, ServerEnabled: gosettings.CopyPointer(d.ServerEnabled),
UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod), UpstreamType: d.UpstreamType,
Providers: gosettings.CopySlice(d.Providers), UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod),
Caching: gosettings.CopyPointer(d.Caching), Providers: gosettings.CopySlice(d.Providers),
IPv6: gosettings.CopyPointer(d.IPv6), Caching: gosettings.CopyPointer(d.Caching),
Blacklist: d.Blacklist.copy(), IPv6: gosettings.CopyPointer(d.IPv6),
UpstreamPlainAddresses: d.UpstreamPlainAddresses, Blacklist: d.Blacklist.copy(),
ServerAddress: d.ServerAddress,
KeepNameserver: gosettings.CopyPointer(d.KeepNameserver),
} }
} }
@@ -112,17 +105,20 @@ func (d *DNS) Copy() (copied DNS) {
// settings object with any field set in the other // settings object with any field set in the other
// settings. // settings.
func (d *DNS) overrideWith(other DNS) { func (d *DNS) overrideWith(other DNS) {
d.ServerEnabled = gosettings.OverrideWithPointer(d.ServerEnabled, other.ServerEnabled)
d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType) d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType)
d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod) d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod)
d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers) d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers)
d.Caching = gosettings.OverrideWithPointer(d.Caching, other.Caching) d.Caching = gosettings.OverrideWithPointer(d.Caching, other.Caching)
d.IPv6 = gosettings.OverrideWithPointer(d.IPv6, other.IPv6) d.IPv6 = gosettings.OverrideWithPointer(d.IPv6, other.IPv6)
d.Blacklist.overrideWith(other.Blacklist) d.Blacklist.overrideWith(other.Blacklist)
d.UpstreamPlainAddresses = gosettings.OverrideWithSlice(d.UpstreamPlainAddresses, other.UpstreamPlainAddresses) d.ServerAddress = gosettings.OverrideWithValidator(d.ServerAddress, other.ServerAddress)
d.KeepNameserver = gosettings.OverrideWithPointer(d.KeepNameserver, other.KeepNameserver)
} }
func (d *DNS) setDefaults() { func (d *DNS) setDefaults() {
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, DNSUpstreamTypeDot) d.ServerEnabled = gosettings.DefaultPointer(d.ServerEnabled, true)
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, "dot")
const defaultUpdatePeriod = 24 * time.Hour const defaultUpdatePeriod = 24 * time.Hour
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod) d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
d.Providers = gosettings.DefaultSlice(d.Providers, []string{ d.Providers = gosettings.DefaultSlice(d.Providers, []string{
@@ -131,53 +127,26 @@ func (d *DNS) setDefaults() {
d.Caching = gosettings.DefaultPointer(d.Caching, true) d.Caching = gosettings.DefaultPointer(d.Caching, true)
d.IPv6 = gosettings.DefaultPointer(d.IPv6, false) d.IPv6 = gosettings.DefaultPointer(d.IPv6, false)
d.Blacklist.setDefaults() d.Blacklist.setDefaults()
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{}) d.ServerAddress = gosettings.DefaultValidator(d.ServerAddress,
} netip.AddrFrom4([4]byte{127, 0, 0, 1}))
d.KeepNameserver = gosettings.DefaultPointer(d.KeepNameserver, false)
func defaultDNSProviders() []string {
return []string{
provider.Cloudflare().Name,
}
} }
func (d DNS) GetFirstPlaintextIPv4() (ipv4 netip.Addr) { func (d DNS) GetFirstPlaintextIPv4() (ipv4 netip.Addr) {
if d.UpstreamType == DNSUpstreamTypePlain { localhost := netip.AddrFrom4([4]byte{127, 0, 0, 1})
for _, addrPort := range d.UpstreamPlainAddresses { if d.ServerAddress.Compare(localhost) != 0 && d.ServerAddress.Is4() {
if addrPort.Addr().Is4() { return d.ServerAddress
return addrPort.Addr()
}
}
} }
ipv4 = findPlainIPv4InProviders(d.Providers)
if ipv4.IsValid() {
return ipv4
}
// Either:
// - all upstream plain addresses are IPv6 and no provider is set
// - all providers set do not have a plaintext IPv4 address
ipv4 = findPlainIPv4InProviders(defaultDNSProviders())
if !ipv4.IsValid() {
panic("no plaintext IPv4 address found in default DNS providers")
}
return ipv4
}
func findPlainIPv4InProviders(providerNames []string) netip.Addr {
providers := provider.NewProviders() providers := provider.NewProviders()
for _, name := range providerNames { provider, err := providers.Get(d.Providers[0])
provider, err := providers.Get(name) if err != nil {
if err != nil { // Settings should be validated before calling this function,
// Settings should be validated before calling this function, // so an error happening here is a programming error.
// so an error happening here is a programming error. panic(err)
panic(err)
}
if len(provider.Plain.IPv4) > 0 {
return provider.Plain.IPv4[0].Addr()
}
} }
return netip.Addr{}
return provider.Plain.IPv4[0].Addr()
} }
func (d DNS) String() string { func (d DNS) String() string {
@@ -186,22 +155,22 @@ func (d DNS) String() string {
func (d DNS) toLinesNode() (node *gotree.Node) { func (d DNS) toLinesNode() (node *gotree.Node) {
node = gotree.New("DNS settings:") node = gotree.New("DNS settings:")
node.Appendf("Keep existing nameserver(s): %s", gosettings.BoolToYesNo(d.KeepNameserver))
if *d.KeepNameserver {
return node
}
node.Appendf("DNS server address to use: %s", d.ServerAddress)
node.Appendf("DNS forwarder server enabled: %s", gosettings.BoolToYesNo(d.ServerEnabled))
if !*d.ServerEnabled {
return node
}
node.Appendf("Upstream resolver type: %s", d.UpstreamType) node.Appendf("Upstream resolver type: %s", d.UpstreamType)
upstreamResolvers := node.Append("Upstream resolvers:") upstreamResolvers := node.Append("Upstream resolvers:")
if len(d.UpstreamPlainAddresses) > 0 { for _, provider := range d.Providers {
if d.UpstreamType == DNSUpstreamTypePlain { upstreamResolvers.Append(provider)
for _, addr := range d.UpstreamPlainAddresses {
upstreamResolvers.Append(addr.String())
}
} else {
node.Appendf("Upstream plain addresses: ignored because upstream type is not plain")
}
} else {
for _, provider := range d.Providers {
upstreamResolvers.Append(provider)
}
} }
node.Appendf("Caching: %s", gosettings.BoolToYesNo(d.Caching)) node.Appendf("Caching: %s", gosettings.BoolToYesNo(d.Caching))
@@ -219,6 +188,11 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
} }
func (d *DNS) read(r *reader.Reader) (err error) { func (d *DNS) read(r *reader.Reader) (err error) {
d.ServerEnabled, err = r.BoolPtr("DNS_SERVER", reader.RetroKeys("DOT"))
if err != nil {
return err
}
d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE") d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE")
d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD") d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD")
@@ -243,43 +217,15 @@ func (d *DNS) read(r *reader.Reader) (err error) {
return err return err
} }
err = d.readUpstreamPlainAddresses(r) d.ServerAddress, err = r.NetipAddr("DNS_ADDRESS", reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"))
if err != nil {
return err
}
d.KeepNameserver, err = r.BoolPtr("DNS_KEEP_NAMESERVER")
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_TYPE=plain
// for these to be used. This is an added safety measure to reduce misunderstandings, and
// reduce odd settings overrides.
d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES")
if err != nil {
return err
}
// Retro-compatibility - remove in v4
// If DNS_ADDRESS is set to a non-localhost address, append it to the other
// upstream plain addresses, assuming port 53, and force the upstream type to plain AND
// clear any user picked providers, to maintain retro-compatibility behavior.
serverAddress, err := r.NetipAddr("DNS_ADDRESS",
reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"),
reader.IsRetro("DNS_UPSTREAM_PLAIN_ADDRESSES"))
if err != nil {
return err
} else if !serverAddress.IsValid() {
return nil
}
isLocalhost := serverAddress.Compare(netip.AddrFrom4([4]byte{127, 0, 0, 1})) == 0
if isLocalhost {
return nil
}
const defaultPlainPort = 53
addrPort := netip.AddrPortFrom(serverAddress, defaultPlainPort)
d.UpstreamPlainAddresses = append(d.UpstreamPlainAddresses, addrPort)
d.UpstreamType = DNSUpstreamTypePlain
d.Providers = []string{}
return nil
}

View File

@@ -1,26 +0,0 @@
package settings
import (
"testing"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/require"
)
func Test_defaultDNSProviders(t *testing.T) {
t.Parallel()
names := defaultDNSProviders()
found := false
providers := provider.NewProviders()
for _, name := range names {
provider, err := providers.Get(name)
require.NoError(t, err)
if len(provider.Plain.IPv4) > 0 {
found = true
break
}
}
require.True(t, found, "no default DNS provider has a plaintext IPv4 address")
}

View File

@@ -22,9 +22,6 @@ type DNSBlacklist struct {
AddBlockedHosts []string AddBlockedHosts []string
AddBlockedIPs []netip.Addr AddBlockedIPs []netip.Addr
AddBlockedIPPrefixes []netip.Prefix AddBlockedIPPrefixes []netip.Prefix
// RebindingProtectionExemptHostnames is a list of hostnames
// exempt from DNS rebinding protection.
RebindingProtectionExemptHostnames []string
} }
func (b *DNSBlacklist) setDefaults() { func (b *DNSBlacklist) setDefaults() {
@@ -36,9 +33,8 @@ func (b *DNSBlacklist) setDefaults() {
var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9_])(\.([a-zA-Z0-9]|[a-zA-Z0-9_][a-zA-Z0-9\-_]{0,61}[a-zA-Z0-9]))*$`) //nolint:lll
var ( var (
ErrAllowedHostNotValid = errors.New("allowed host is not valid") ErrAllowedHostNotValid = errors.New("allowed host is not valid")
ErrBlockedHostNotValid = errors.New("blocked host is not valid") ErrBlockedHostNotValid = errors.New("blocked host is not valid")
ErrRebindingProtectionExemptHostNotValid = errors.New("rebinding protection exempt host is not valid")
) )
func (b DNSBlacklist) validate() (err error) { func (b DNSBlacklist) validate() (err error) {
@@ -54,25 +50,18 @@ func (b DNSBlacklist) validate() (err error) {
} }
} }
for _, host := range b.RebindingProtectionExemptHostnames {
if !hostRegex.MatchString(host) {
return fmt.Errorf("%w: %s", ErrRebindingProtectionExemptHostNotValid, host)
}
}
return nil return nil
} }
func (b DNSBlacklist) copy() (copied DNSBlacklist) { func (b DNSBlacklist) copy() (copied DNSBlacklist) {
return DNSBlacklist{ return DNSBlacklist{
BlockMalicious: gosettings.CopyPointer(b.BlockMalicious), BlockMalicious: gosettings.CopyPointer(b.BlockMalicious),
BlockAds: gosettings.CopyPointer(b.BlockAds), BlockAds: gosettings.CopyPointer(b.BlockAds),
BlockSurveillance: gosettings.CopyPointer(b.BlockSurveillance), BlockSurveillance: gosettings.CopyPointer(b.BlockSurveillance),
AllowedHosts: gosettings.CopySlice(b.AllowedHosts), AllowedHosts: gosettings.CopySlice(b.AllowedHosts),
AddBlockedHosts: gosettings.CopySlice(b.AddBlockedHosts), AddBlockedHosts: gosettings.CopySlice(b.AddBlockedHosts),
AddBlockedIPs: gosettings.CopySlice(b.AddBlockedIPs), AddBlockedIPs: gosettings.CopySlice(b.AddBlockedIPs),
AddBlockedIPPrefixes: gosettings.CopySlice(b.AddBlockedIPPrefixes), AddBlockedIPPrefixes: gosettings.CopySlice(b.AddBlockedIPPrefixes),
RebindingProtectionExemptHostnames: gosettings.CopySlice(b.RebindingProtectionExemptHostnames),
} }
} }
@@ -84,8 +73,6 @@ func (b *DNSBlacklist) overrideWith(other DNSBlacklist) {
b.AddBlockedHosts = gosettings.OverrideWithSlice(b.AddBlockedHosts, other.AddBlockedHosts) b.AddBlockedHosts = gosettings.OverrideWithSlice(b.AddBlockedHosts, other.AddBlockedHosts)
b.AddBlockedIPs = gosettings.OverrideWithSlice(b.AddBlockedIPs, other.AddBlockedIPs) b.AddBlockedIPs = gosettings.OverrideWithSlice(b.AddBlockedIPs, other.AddBlockedIPs)
b.AddBlockedIPPrefixes = gosettings.OverrideWithSlice(b.AddBlockedIPPrefixes, other.AddBlockedIPPrefixes) b.AddBlockedIPPrefixes = gosettings.OverrideWithSlice(b.AddBlockedIPPrefixes, other.AddBlockedIPPrefixes)
b.RebindingProtectionExemptHostnames = gosettings.OverrideWithSlice(b.RebindingProtectionExemptHostnames,
other.RebindingProtectionExemptHostnames)
} }
func (b DNSBlacklist) ToBlockBuilderSettings(client *http.Client) ( func (b DNSBlacklist) ToBlockBuilderSettings(client *http.Client) (
@@ -142,13 +129,6 @@ func (b DNSBlacklist) toLinesNode() (node *gotree.Node) {
} }
} }
if len(b.RebindingProtectionExemptHostnames) > 0 {
exemptHostsNode := node.Append("Rebinding protection exempt hostnames:")
for _, host := range b.RebindingProtectionExemptHostnames {
exemptHostsNode.Append(host)
}
}
return node return node
} }
@@ -176,8 +156,6 @@ func (b *DNSBlacklist) read(r *reader.Reader) (err error) {
b.AllowedHosts = r.CSV("DNS_UNBLOCK_HOSTNAMES", reader.RetroKeys("UNBLOCK")) b.AllowedHosts = r.CSV("DNS_UNBLOCK_HOSTNAMES", reader.RetroKeys("UNBLOCK"))
b.RebindingProtectionExemptHostnames = r.CSV("DNS_REBINDING_PROTECTION_EXEMPT_HOSTNAMES")
return nil return nil
} }

View File

@@ -36,8 +36,6 @@ var (
ErrSystemPUIDNotValid = errors.New("process user id is not valid") ErrSystemPUIDNotValid = errors.New("process user id is not valid")
ErrSystemTimezoneNotValid = errors.New("timezone is not valid") ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small") ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
ErrUpdaterProtonUsernameMissing = errors.New("proton username is missing")
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid") ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
ErrVPNTypeNotValid = errors.New("VPN type is not valid") ErrVPNTypeNotValid = errors.New("VPN type is not valid")
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set") ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"os" "os"
"time"
"github.com/qdm12/gosettings" "github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader" "github.com/qdm12/gosettings/reader"
@@ -17,6 +18,12 @@ type Health struct {
// for the health check server. // for the health check server.
// It cannot be the empty string in the internal state. // It cannot be the empty string in the internal state.
ServerAddress string ServerAddress string
// ReadHeaderTimeout is the HTTP server header read timeout
// duration of the HTTP server. It defaults to 100 milliseconds.
ReadHeaderTimeout time.Duration
// ReadTimeout is the HTTP read timeout duration of the
// HTTP server. It defaults to 500 milliseconds.
ReadTimeout time.Duration
// TargetAddress is the address (host or host:port) // TargetAddress is the address (host or host:port)
// to TCP TLS dial to periodically for the health check. // to TCP TLS dial to periodically for the health check.
// It cannot be the empty string in the internal state. // It cannot be the empty string in the internal state.
@@ -41,10 +48,12 @@ func (h Health) Validate() (err error) {
func (h *Health) copy() (copied Health) { func (h *Health) copy() (copied Health) {
return Health{ return Health{
ServerAddress: h.ServerAddress, ServerAddress: h.ServerAddress,
TargetAddress: h.TargetAddress, ReadHeaderTimeout: h.ReadHeaderTimeout,
ICMPTargetIP: h.ICMPTargetIP, ReadTimeout: h.ReadTimeout,
RestartVPN: gosettings.CopyPointer(h.RestartVPN), TargetAddress: h.TargetAddress,
ICMPTargetIP: h.ICMPTargetIP,
RestartVPN: gosettings.CopyPointer(h.RestartVPN),
} }
} }
@@ -53,6 +62,8 @@ func (h *Health) copy() (copied Health) {
// settings. // settings.
func (h *Health) OverrideWith(other Health) { func (h *Health) OverrideWith(other Health) {
h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress) h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress)
h.ReadHeaderTimeout = gosettings.OverrideWithComparable(h.ReadHeaderTimeout, other.ReadHeaderTimeout)
h.ReadTimeout = gosettings.OverrideWithComparable(h.ReadTimeout, other.ReadTimeout)
h.TargetAddress = gosettings.OverrideWithComparable(h.TargetAddress, other.TargetAddress) h.TargetAddress = gosettings.OverrideWithComparable(h.TargetAddress, other.TargetAddress)
h.ICMPTargetIP = gosettings.OverrideWithComparable(h.ICMPTargetIP, other.ICMPTargetIP) h.ICMPTargetIP = gosettings.OverrideWithComparable(h.ICMPTargetIP, other.ICMPTargetIP)
h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN) h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN)
@@ -60,6 +71,10 @@ func (h *Health) OverrideWith(other Health) {
func (h *Health) SetDefaults() { func (h *Health) SetDefaults() {
h.ServerAddress = gosettings.DefaultComparable(h.ServerAddress, "127.0.0.1:9999") h.ServerAddress = gosettings.DefaultComparable(h.ServerAddress, "127.0.0.1:9999")
const defaultReadHeaderTimeout = 100 * time.Millisecond
h.ReadHeaderTimeout = gosettings.DefaultComparable(h.ReadHeaderTimeout, defaultReadHeaderTimeout)
const defaultReadTimeout = 500 * time.Millisecond
h.ReadTimeout = gosettings.DefaultComparable(h.ReadTimeout, defaultReadTimeout)
h.TargetAddress = gosettings.DefaultComparable(h.TargetAddress, "cloudflare.com:443") h.TargetAddress = gosettings.DefaultComparable(h.TargetAddress, "cloudflare.com:443")
h.ICMPTargetIP = gosettings.DefaultComparable(h.ICMPTargetIP, netip.IPv4Unspecified()) // use the VPN server IP h.ICMPTargetIP = gosettings.DefaultComparable(h.ICMPTargetIP, netip.IPv4Unspecified()) // use the VPN server IP
h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true) h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true)

View File

@@ -1,14 +1,11 @@
package settings package settings
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"net" "net"
"os" "os"
"strconv" "strconv"
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
"github.com/qdm12/gosettings" "github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader" "github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree" "github.com/qdm12/gotree"
@@ -27,9 +24,6 @@ type ControlServer struct {
// It cannot be empty in the internal state and defaults to // It cannot be empty in the internal state and defaults to
// /gluetun/auth/config.toml. // /gluetun/auth/config.toml.
AuthFilePath string AuthFilePath string
// AuthDefaultRole is a JSON encoded object defining the default role
// that applies to all routes without a previously user-defined role assigned to.
AuthDefaultRole string
} }
func (c ControlServer) validate() (err error) { func (c ControlServer) validate() (err error) {
@@ -50,30 +44,14 @@ func (c ControlServer) validate() (err error) {
ErrControlServerPrivilegedPort, port, uid) ErrControlServerPrivilegedPort, port, uid)
} }
jsonDecoder := json.NewDecoder(bytes.NewBufferString(c.AuthDefaultRole))
jsonDecoder.DisallowUnknownFields()
var role auth.Role
err = jsonDecoder.Decode(&role)
if err != nil {
return fmt.Errorf("default authentication role is not valid JSON: %w", err)
}
if role.Auth != "" {
err = role.Validate()
if err != nil {
return fmt.Errorf("default authentication role is not valid: %w", err)
}
}
return nil return nil
} }
func (c *ControlServer) copy() (copied ControlServer) { func (c *ControlServer) copy() (copied ControlServer) {
return ControlServer{ return ControlServer{
Address: gosettings.CopyPointer(c.Address), Address: gosettings.CopyPointer(c.Address),
Log: gosettings.CopyPointer(c.Log), Log: gosettings.CopyPointer(c.Log),
AuthFilePath: c.AuthFilePath, AuthFilePath: c.AuthFilePath,
AuthDefaultRole: c.AuthDefaultRole,
} }
} }
@@ -84,21 +62,12 @@ func (c *ControlServer) overrideWith(other ControlServer) {
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address) c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log) c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath) c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
c.AuthDefaultRole = gosettings.OverrideWithComparable(c.AuthDefaultRole, other.AuthDefaultRole)
} }
func (c *ControlServer) setDefaults() { func (c *ControlServer) setDefaults() {
c.Address = gosettings.DefaultPointer(c.Address, ":8000") c.Address = gosettings.DefaultPointer(c.Address, ":8000")
c.Log = gosettings.DefaultPointer(c.Log, true) c.Log = gosettings.DefaultPointer(c.Log, true)
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml") c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
c.AuthDefaultRole = gosettings.DefaultComparable(c.AuthDefaultRole, "{}")
if c.AuthDefaultRole != "{}" {
var role auth.Role
_ = json.Unmarshal([]byte(c.AuthDefaultRole), &role)
role.Name = "default"
roleBytes, _ := json.Marshal(role) //nolint:errchkjson
c.AuthDefaultRole = string(roleBytes)
}
} }
func (c ControlServer) String() string { func (c ControlServer) String() string {
@@ -110,11 +79,6 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
node.Appendf("Listening address: %s", *c.Address) node.Appendf("Listening address: %s", *c.Address)
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log)) node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
node.Appendf("Authentication file path: %s", c.AuthFilePath) node.Appendf("Authentication file path: %s", c.AuthFilePath)
if c.AuthDefaultRole != "{}" {
var role auth.Role
_ = json.Unmarshal([]byte(c.AuthDefaultRole), &role)
node.AppendNode(role.ToLinesNode())
}
return node return node
} }
@@ -127,7 +91,6 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS") c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH") c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
c.AuthDefaultRole = r.String("HTTP_CONTROL_SERVER_AUTH_DEFAULT_ROLE")
return nil return nil
} }

View File

@@ -2,6 +2,7 @@ package settings
import ( import (
"fmt" "fmt"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
@@ -173,11 +174,13 @@ func (s Settings) Warnings() (warnings []string) {
"by creating an issue, attaching the new certificate and we will update Gluetun.") "by creating an issue, attaching the new certificate and we will update Gluetun.")
} }
for _, upstreamAddress := range s.DNS.UpstreamPlainAddresses { // TODO remove in v4
if upstreamAddress.Addr().IsPrivate() { if s.DNS.ServerAddress.Unmap().Compare(netip.AddrFrom4([4]byte{127, 0, 0, 1})) != 0 {
warnings = append(warnings, "DNS upstream address "+upstreamAddress.String()+" is private: "+ warnings = append(warnings, "DNS address is set to "+s.DNS.ServerAddress.String()+
"DNS traffic might leak out of the VPN tunnel to that address.") " so the local forwarding DNS server will not be used."+
} " The default value changed to 127.0.0.1 so it uses the internal DNS server."+
" If this server fails to start, the IPv4 address of the first plaintext DNS server"+
" corresponding to the first DNS provider chosen is used.")
} }
return warnings return warnings

View File

@@ -38,6 +38,9 @@ func Test_Settings_String(t *testing.T) {
| ├── Run OpenVPN as: root | ├── Run OpenVPN as: root
| └── Verbosity level: 1 | └── Verbosity level: 1
├── DNS settings: ├── DNS settings:
| ├── Keep existing nameserver(s): no
| ├── DNS server address to use: 127.0.0.1
| ├── DNS forwarder server enabled: yes
| ├── Upstream resolver type: dot | ├── Upstream resolver type: dot
| ├── Upstream resolvers: | ├── Upstream resolvers:
| | └── Cloudflare | | └── Cloudflare

View File

@@ -2,7 +2,6 @@ package settings
import ( import (
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
@@ -32,10 +31,6 @@ type Updater struct {
// Providers is the list of VPN service providers // Providers is the list of VPN service providers
// to update server information for. // to update server information for.
Providers []string Providers []string
// ProtonUsername is the username to authenticate with the Proton API.
ProtonUsername *string
// ProtonPassword is the password to authenticate with the Proton API.
ProtonPassword *string
} }
func (u Updater) Validate() (err error) { func (u Updater) Validate() (err error) {
@@ -56,18 +51,6 @@ func (u Updater) Validate() (err error) {
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err) return fmt.Errorf("%w: %w", ErrVPNProviderNameNotValid, err)
} }
if provider == providers.Protonvpn {
authenticatedAPI := *u.ProtonUsername != "" || *u.ProtonPassword != ""
if authenticatedAPI {
switch {
case *u.ProtonUsername == "":
return fmt.Errorf("%w", ErrUpdaterProtonUsernameMissing)
case *u.ProtonPassword == "":
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
}
}
}
} }
return nil return nil
@@ -75,12 +58,10 @@ func (u Updater) Validate() (err error) {
func (u *Updater) copy() (copied Updater) { func (u *Updater) copy() (copied Updater) {
return Updater{ return Updater{
Period: gosettings.CopyPointer(u.Period), Period: gosettings.CopyPointer(u.Period),
DNSAddress: u.DNSAddress, DNSAddress: u.DNSAddress,
MinRatio: u.MinRatio, MinRatio: u.MinRatio,
Providers: gosettings.CopySlice(u.Providers), Providers: gosettings.CopySlice(u.Providers),
ProtonUsername: gosettings.CopyPointer(u.ProtonUsername),
ProtonPassword: gosettings.CopyPointer(u.ProtonPassword),
} }
} }
@@ -92,8 +73,6 @@ func (u *Updater) overrideWith(other Updater) {
u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress) u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress)
u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio) u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio)
u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers) u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers)
u.ProtonUsername = gosettings.OverrideWithPointer(u.ProtonUsername, other.ProtonUsername)
u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword)
} }
func (u *Updater) SetDefaults(vpnProvider string) { func (u *Updater) SetDefaults(vpnProvider string) {
@@ -108,10 +87,6 @@ func (u *Updater) SetDefaults(vpnProvider string) {
if len(u.Providers) == 0 && vpnProvider != providers.Custom { if len(u.Providers) == 0 && vpnProvider != providers.Custom {
u.Providers = []string{vpnProvider} u.Providers = []string{vpnProvider}
} }
// Set these to empty strings to avoid nil pointer panics
u.ProtonUsername = gosettings.DefaultPointer(u.ProtonUsername, "")
u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "")
} }
func (u Updater) String() string { func (u Updater) String() string {
@@ -128,10 +103,6 @@ func (u Updater) toLinesNode() (node *gotree.Node) {
node.Appendf("DNS address: %s", u.DNSAddress) node.Appendf("DNS address: %s", u.DNSAddress)
node.Appendf("Minimum ratio: %.1f", u.MinRatio) node.Appendf("Minimum ratio: %.1f", u.MinRatio)
node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", ")) node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", "))
if slices.Contains(u.Providers, providers.Protonvpn) {
node.Appendf("Proton API username: %s", *u.ProtonUsername)
node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword))
}
return node return node
} }
@@ -154,14 +125,6 @@ func (u *Updater) read(r *reader.Reader) (err error) {
u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS") u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS")
u.ProtonUsername = r.Get("UPDATER_PROTONVPN_USERNAME")
if u.ProtonUsername != nil {
// Enforce to use the username not the email address
*u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@protonmail.com")
*u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@proton.me")
}
u.ProtonPassword = r.Get("UPDATER_PROTONVPN_PASSWORD")
return nil return nil
} }

View File

@@ -45,6 +45,7 @@ type Wireguard struct {
// It has been lowered to 1320 following quite a bit of // It has been lowered to 1320 following quite a bit of
// investigation in the issue: // investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533. // https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature.
MTU uint16 `json:"mtu"` MTU uint16 `json:"mtu"`
// Implementation is the Wireguard implementation to use. // Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace". // It can be "auto", "userspace" or "kernelspace".

View File

@@ -155,8 +155,7 @@ func (w WireguardSelection) toLinesNode() (node *gotree.Node) {
func (w *WireguardSelection) read(r *reader.Reader) (err error) { func (w *WireguardSelection) read(r *reader.Reader) (err error) {
w.EndpointIP, err = r.NetipAddr("WIREGUARD_ENDPOINT_IP", reader.RetroKeys("VPN_ENDPOINT_IP")) w.EndpointIP, err = r.NetipAddr("WIREGUARD_ENDPOINT_IP", reader.RetroKeys("VPN_ENDPOINT_IP"))
if err != nil { if err != nil {
return fmt.Errorf("%w - note this MUST be an IP address, "+ return err
"see https://github.com/qdm12/gluetun/issues/788", err)
} }
w.EndpointPort, err = r.Uint16Ptr("WIREGUARD_ENDPOINT_PORT", reader.RetroKeys("VPN_ENDPOINT_PORT")) w.EndpointPort, err = r.Uint16Ptr("WIREGUARD_ENDPOINT_PORT", reader.RetroKeys("VPN_ENDPOINT_PORT"))

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"net/netip"
"time" "time"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/mapfilter" "github.com/qdm12/dns/v2/pkg/middlewares/filter/mapfilter"
@@ -17,23 +16,22 @@ import (
) )
type Loop struct { type Loop struct {
statusManager *loopstate.State statusManager *loopstate.State
state *state.State state *state.State
server *server.Server server *server.Server
filter *mapfilter.Filter filter *mapfilter.Filter
localResolvers []netip.Addr resolvConf string
resolvConf string client *http.Client
client *http.Client logger Logger
logger Logger userTrigger bool
userTrigger bool start <-chan struct{}
start <-chan struct{} running chan<- models.LoopStatus
running chan<- models.LoopStatus stop <-chan struct{}
stop <-chan struct{} stopped chan<- struct{}
stopped chan<- struct{} updateTicker <-chan struct{}
updateTicker <-chan struct{} backoffTime time.Duration
backoffTime time.Duration timeNow func() time.Time
timeNow func() time.Time timeSince func(time.Time) time.Duration
timeSince func(time.Time) time.Duration
} }
const defaultBackoffTime = 10 * time.Second const defaultBackoffTime = 10 * time.Second
@@ -50,9 +48,7 @@ func NewLoop(settings settings.DNS,
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, settings, updateTicker) state := state.New(statusManager, settings, updateTicker)
filter, err := mapfilter.New(mapfilter.Settings{ filter, err := mapfilter.New(mapfilter.Settings{})
Logger: buildFilterLogger(logger),
})
if err != nil { if err != nil {
return nil, fmt.Errorf("creating map filter: %w", err) return nil, fmt.Errorf("creating map filter: %w", err)
} }
@@ -104,15 +100,3 @@ func (l *Loop) signalOrSetStatus(status models.LoopStatus) {
l.statusManager.SetStatus(status) l.statusManager.SetStatus(status)
} }
} }
type filterLogger struct {
logger Logger
}
func (l *filterLogger) Log(msg string) {
l.logger.Info(msg)
}
func buildFilterLogger(logger Logger) *filterLogger {
return &filterLogger{logger: logger}
}

View File

@@ -4,23 +4,21 @@ import (
"context" "context"
"errors" "errors"
"github.com/qdm12/dns/v2/pkg/nameserver"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
) )
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
defer close(done) defer close(done)
var err error if *l.GetSettings().KeepNameserver {
l.localResolvers, err = nameserver.GetPrivateDNSServers() l.logger.Warn("⚠️⚠️⚠️ keeping the default container nameservers, " +
if err != nil { "this will likely leak DNS traffic outside the VPN " +
l.logger.Error("getting private DNS servers: " + err.Error()) "and go through your container network DNS outside the VPN tunnel!")
return } else {
const fallback = false
l.useUnencryptedDNS(fallback)
} }
const fallback = false
l.useUnencryptedDNS(fallback)
select { select {
case <-l.start: case <-l.start:
case <-ctx.Done(): case <-ctx.Done():
@@ -32,12 +30,14 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
// Their values are to be used if DOT=off // Their values are to be used if DOT=off
var runError <-chan error var runError <-chan error
for { settings := l.GetSettings()
for !*settings.KeepNameserver && *settings.ServerEnabled {
var err error var err error
runError, err = l.setupServer(ctx) runError, err = l.setupServer(ctx)
if err == nil { if err == nil {
l.backoffTime = defaultBackoffTime l.backoffTime = defaultBackoffTime
l.logger.Info("ready") l.logger.Info("ready")
l.signalOrSetStatus(constants.Running)
break break
} }
@@ -52,11 +52,14 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
} }
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
settings = l.GetSettings()
} }
l.signalOrSetStatus(constants.Running)
const fallback = false settings = l.GetSettings()
l.useUnencryptedDNS(fallback) if !*settings.KeepNameserver && !*settings.ServerEnabled {
const fallback = false
l.useUnencryptedDNS(fallback)
}
l.userTrigger = false l.userTrigger = false

View File

@@ -3,7 +3,6 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/doh" "github.com/qdm12/dns/v2/pkg/doh"
"github.com/qdm12/dns/v2/pkg/dot" "github.com/qdm12/dns/v2/pkg/dot"
@@ -11,7 +10,6 @@ import (
"github.com/qdm12/dns/v2/pkg/middlewares/cache/lru" "github.com/qdm12/dns/v2/pkg/middlewares/cache/lru"
filtermiddleware "github.com/qdm12/dns/v2/pkg/middlewares/filter" filtermiddleware "github.com/qdm12/dns/v2/pkg/middlewares/filter"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/mapfilter" "github.com/qdm12/dns/v2/pkg/middlewares/filter/mapfilter"
"github.com/qdm12/dns/v2/pkg/middlewares/localdns"
"github.com/qdm12/dns/v2/pkg/plain" "github.com/qdm12/dns/v2/pkg/plain"
"github.com/qdm12/dns/v2/pkg/provider" "github.com/qdm12/dns/v2/pkg/provider"
"github.com/qdm12/dns/v2/pkg/server" "github.com/qdm12/dns/v2/pkg/server"
@@ -26,23 +24,30 @@ func (l *Loop) SetSettings(ctx context.Context, settings settings.DNS) (
return l.state.SetSettings(ctx, settings) return l.state.SetSettings(ctx, settings)
} }
func buildServerSettings(userSettings settings.DNS, func buildServerSettings(settings settings.DNS,
filter *mapfilter.Filter, localResolvers []netip.Addr, filter *mapfilter.Filter, logger Logger) (
logger Logger) (
serverSettings server.Settings, err error, serverSettings server.Settings, err error,
) { ) {
serverSettings.Logger = logger serverSettings.Logger = logger
upstreamResolvers := buildProviders(userSettings) providersData := provider.NewProviders()
upstreamResolvers := make([]provider.Provider, len(settings.Providers))
for i := range settings.Providers {
var err error
upstreamResolvers[i], err = providersData.Get(settings.Providers[i])
if err != nil {
panic(err) // this should already had been checked
}
}
ipVersion := "ipv4" ipVersion := "ipv4"
if *userSettings.IPv6 { if *settings.IPv6 {
ipVersion = "ipv6" ipVersion = "ipv6"
} }
var dialer server.Dialer var dialer server.Dialer
switch userSettings.UpstreamType { switch settings.UpstreamType {
case settings.DNSUpstreamTypeDot: case "dot":
dialerSettings := dot.Settings{ dialerSettings := dot.Settings{
UpstreamResolvers: upstreamResolvers, UpstreamResolvers: upstreamResolvers,
IPVersion: ipVersion, IPVersion: ipVersion,
@@ -51,7 +56,7 @@ func buildServerSettings(userSettings settings.DNS,
if err != nil { if err != nil {
return server.Settings{}, fmt.Errorf("creating DNS over TLS dialer: %w", err) return server.Settings{}, fmt.Errorf("creating DNS over TLS dialer: %w", err)
} }
case settings.DNSUpstreamTypeDoh: case "doh":
dialerSettings := doh.Settings{ dialerSettings := doh.Settings{
UpstreamResolvers: upstreamResolvers, UpstreamResolvers: upstreamResolvers,
IPVersion: ipVersion, IPVersion: ipVersion,
@@ -60,7 +65,7 @@ func buildServerSettings(userSettings settings.DNS,
if err != nil { if err != nil {
return server.Settings{}, fmt.Errorf("creating DNS over HTTPS dialer: %w", err) return server.Settings{}, fmt.Errorf("creating DNS over HTTPS dialer: %w", err)
} }
case settings.DNSUpstreamTypePlain: case "plain":
dialerSettings := plain.Settings{ dialerSettings := plain.Settings{
UpstreamResolvers: upstreamResolvers, UpstreamResolvers: upstreamResolvers,
IPVersion: ipVersion, IPVersion: ipVersion,
@@ -70,11 +75,11 @@ func buildServerSettings(userSettings settings.DNS,
return server.Settings{}, fmt.Errorf("creating plain DNS dialer: %w", err) return server.Settings{}, fmt.Errorf("creating plain DNS dialer: %w", err)
} }
default: default:
panic("unknown upstream type: " + userSettings.UpstreamType) panic("unknown upstream type: " + settings.UpstreamType)
} }
serverSettings.Dialer = dialer serverSettings.Dialer = dialer
if *userSettings.Caching { if *settings.Caching {
lruCache, err := lru.New(lru.Settings{}) lruCache, err := lru.New(lru.Settings{})
if err != nil { if err != nil {
return server.Settings{}, fmt.Errorf("creating LRU cache: %w", err) return server.Settings{}, fmt.Errorf("creating LRU cache: %w", err)
@@ -96,67 +101,5 @@ func buildServerSettings(userSettings settings.DNS,
} }
serverSettings.Middlewares = append(serverSettings.Middlewares, filterMiddleware) serverSettings.Middlewares = append(serverSettings.Middlewares, filterMiddleware)
localResolversAddrPorts := make([]netip.AddrPort, len(localResolvers))
const defaultDNSPort = 53
for i, addr := range localResolvers {
localResolversAddrPorts[i] = netip.AddrPortFrom(addr, defaultDNSPort)
}
localDNSMiddleware, err := localdns.New(localdns.Settings{
Resolvers: localResolversAddrPorts, // auto-detected at container start only
Logger: logger,
})
if err != nil {
return server.Settings{}, fmt.Errorf("creating local DNS middleware: %w", err)
}
// Place after cache middleware, since we want to avoid caching for local
// hostnames that may change regularly.
// Place after filter middleware to avoid conflicts with the rebinding protection.
serverSettings.Middlewares = append(serverSettings.Middlewares, localDNSMiddleware)
return serverSettings, nil return serverSettings, nil
} }
func buildProviders(userSettings settings.DNS) []provider.Provider {
if userSettings.UpstreamType == settings.DNSUpstreamTypePlain &&
len(userSettings.UpstreamPlainAddresses) > 0 {
providers := make([]provider.Provider, len(userSettings.UpstreamPlainAddresses))
for i, addrPort := range userSettings.UpstreamPlainAddresses {
providers[i] = provider.Provider{
Name: addrPort.String(),
}
if addrPort.Addr().Is4() {
providers[i].Plain.IPv4 = []netip.AddrPort{addrPort}
} else {
providers[i].Plain.IPv6 = []netip.AddrPort{addrPort}
}
}
}
providersData := provider.NewProviders()
providers := make([]provider.Provider, 0, len(userSettings.Providers)+len(userSettings.UpstreamPlainAddresses))
for _, providerName := range userSettings.Providers {
provider, err := providersData.Get(providerName)
if err != nil {
panic(err) // this should already had been checked
}
providers = append(providers, provider)
}
if userSettings.UpstreamType != settings.DNSUpstreamTypePlain {
return providers
}
for _, addrPort := range userSettings.UpstreamPlainAddresses {
newProvider := provider.Provider{
Name: addrPort.String(),
}
if addrPort.Addr().Is4() {
newProvider.Plain.IPv4 = []netip.AddrPort{addrPort}
} else {
newProvider.Plain.IPv6 = []netip.AddrPort{addrPort}
}
providers = append(providers, newProvider)
}
return providers
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/dns/v2/pkg/nameserver" "github.com/qdm12/dns/v2/pkg/nameserver"
@@ -20,7 +21,7 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
settings := l.GetSettings() settings := l.GetSettings()
serverSettings, err := buildServerSettings(settings, l.filter, l.localResolvers, l.logger) serverSettings, err := buildServerSettings(settings, l.filter, l.logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("building server settings: %w", err) return nil, fmt.Errorf("building server settings: %w", err)
} }
@@ -37,8 +38,12 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
l.server = server l.server = server
// use internal DNS server // use internal DNS server
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{}) const defaultDNSPort = 53
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
})
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{ err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
IPs: []netip.Addr{settings.ServerAddress},
ResolvPath: l.resolvConf, ResolvPath: l.resolvConf,
}) })
if err != nil { if err != nil {

View File

@@ -40,6 +40,8 @@ func (s *State) SetSettings(ctx context.Context, settings settings.DNS) (
// Restart // Restart
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped) _, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running) if *settings.ServerEnabled {
outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
}
return outcome return outcome
} }

View File

@@ -37,7 +37,6 @@ func (l *Loop) updateFiles(ctx context.Context) (err error) {
IPPrefixes: result.BlockedIPPrefixes, IPPrefixes: result.BlockedIPPrefixes,
} }
updateSettings.BlockHostnames(result.BlockedHostnames) updateSettings.BlockHostnames(result.BlockedHostnames)
updateSettings.SetRebindingProtectionExempt(settings.Blacklist.RebindingProtectionExemptHostnames)
err = l.filter.Update(updateSettings) err = l.filter.Update(updateSettings)
if err != nil { if err != nil {
return fmt.Errorf("updating filter: %w", err) return fmt.Errorf("updating filter: %w", err)

View File

@@ -131,18 +131,9 @@ func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
c.configMutex.Lock() c.configMutex.Lock()
ip := c.icmpTarget ip := c.icmpTarget
c.configMutex.Unlock() c.configMutex.Unlock()
tryTimeouts := []time.Duration{ const maxTries = 3
5 * time.Second, const timeout = 10 * time.Second
5 * time.Second, const extraTryTime = 10 * time.Second // 10s added for each subsequent retry
5 * time.Second,
10 * time.Second,
10 * time.Second,
10 * time.Second,
15 * time.Second,
15 * time.Second,
15 * time.Second,
30 * time.Second,
}
check := func(ctx context.Context) error { check := func(ctx context.Context) error {
if c.icmpNotPermitted { if c.icmpNotPermitted {
return c.dnsClient.Check(ctx) return c.dnsClient.Check(ctx)
@@ -156,17 +147,19 @@ func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
} }
return err return err
} }
return withRetries(ctx, tryTimeouts, c.logger, c.smallCheckName, check) return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, c.smallCheckName, check)
} }
func (c *Checker) fullPeriodicCheck(ctx context.Context) error { func (c *Checker) fullPeriodicCheck(ctx context.Context) error {
const maxTries = 2
// 20s timeout in case the connection is under stress // 20s timeout in case the connection is under stress
// See https://github.com/qdm12/gluetun/issues/2270 // See https://github.com/qdm12/gluetun/issues/2270
tryTimeouts := []time.Duration{10 * time.Second, 15 * time.Second, 30 * time.Second} const timeout = 20 * time.Second
const extraTryTime = 10 * time.Second // 10s added for each subsequent retry
check := func(ctx context.Context) error { check := func(ctx context.Context) error {
return tcpTLSCheck(ctx, c.dialer, c.tlsDialAddr) return tcpTLSCheck(ctx, c.dialer, c.tlsDialAddr)
} }
return withRetries(ctx, tryTimeouts, c.logger, "TCP+TLS dial", check) return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, "TCP+TLS dial", check)
} }
func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error { func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error {
@@ -225,17 +218,13 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
var ErrAllCheckTriesFailed = errors.New("all check tries failed") var ErrAllCheckTriesFailed = errors.New("all check tries failed")
func withRetries(ctx context.Context, tryTimeouts []time.Duration, func withRetries(ctx context.Context, maxTries uint, tryTimeout, extraTryTime time.Duration,
logger Logger, checkName string, check func(ctx context.Context) error, logger Logger, checkName string, check func(ctx context.Context) error,
) error { ) error {
maxTries := len(tryTimeouts) try := uint(0)
type errData struct { var errs []error
err error for {
duration time.Duration timeout := tryTimeout + time.Duration(try)*extraTryTime //nolint:gosec
}
errs := make([]errData, maxTries)
for i, timeout := range tryTimeouts {
start := time.Now()
checkCtx, cancel := context.WithTimeout(ctx, timeout) checkCtx, cancel := context.WithTimeout(ctx, timeout)
err := check(checkCtx) err := check(checkCtx)
cancel() cancel()
@@ -245,14 +234,17 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
case ctx.Err() != nil: case ctx.Err() != nil:
return fmt.Errorf("%s: %w", checkName, ctx.Err()) return fmt.Errorf("%s: %w", checkName, ctx.Err())
} }
logger.Debugf("%s attempt %d/%d failed: %s", checkName, i+1, maxTries, err) logger.Debugf("%s attempt %d/%d failed: %s", checkName, try+1, maxTries, err)
errs[i].err = err errs = append(errs, err)
errs[i].duration = time.Since(start) try++
if try < maxTries {
continue
}
errStrings := make([]string, len(errs))
for i, err := range errs {
errStrings[i] = fmt.Sprintf("attempt %d: %s", i+1, err.Error())
}
return fmt.Errorf("%w: after %d %s attempts (%s)",
ErrAllCheckTriesFailed, maxTries, checkName, strings.Join(errStrings, "; "))
} }
errStrings := make([]string, len(errs))
for i, err := range errs {
errStrings[i] = fmt.Sprintf("attempt %d (%s): %s", i+1, err.duration, err.err)
}
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
} }

View File

@@ -44,12 +44,12 @@ func concatAddrPorts(addrs [][]netip.AddrPort) []netip.AddrPort {
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup") var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
func (c *Client) Check(ctx context.Context) error { func (c *Client) Check(ctx context.Context) error {
dnsAddr := c.serverAddrs[c.dnsIPIndex].String() dnsAddr := c.serverAddrs[c.dnsIPIndex].Addr()
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{} dialer := net.Dialer{}
return dialer.DialContext(ctx, "udp", dnsAddr) return dialer.DialContext(ctx, "udp", dnsAddr.String())
}, },
} }
ips, err := resolver.LookupIP(ctx, "ip", "github.com") ips, err := resolver.LookupIP(ctx, "ip", "github.com")

View File

@@ -10,13 +10,11 @@ import (
func (s *Server) Run(ctx context.Context, done chan<- struct{}) { func (s *Server) Run(ctx context.Context, done chan<- struct{}) {
defer close(done) defer close(done)
const readHeaderTimeout = 100 * time.Millisecond
const readTimeout = 500 * time.Millisecond
server := http.Server{ server := http.Server{
Addr: s.config.ServerAddress, Addr: s.config.ServerAddress,
Handler: s.handler, Handler: s.handler,
ReadHeaderTimeout: readHeaderTimeout, ReadHeaderTimeout: s.config.ReadHeaderTimeout,
ReadTimeout: readTimeout, ReadTimeout: s.config.ReadTimeout,
} }
serverDone := make(chan struct{}) serverDone := make(chan struct{})
go func() { go func() {

View File

@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link)) return netlink.LinkSetDown(linkToNetlinkLink(&link))
} }
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
}
type netlinkLinkImpl struct { type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs attrs *netlink.LinkAttrs
linkType string linkType string

View File

@@ -0,0 +1,49 @@
package pmtud
import (
"net"
"time"
"golang.org/x/net/ipv4"
)
var _ net.PacketConn = &ipv4Wrapper{}
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}
func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}
func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}

83
internal/pmtud/check.go Normal file
View File

@@ -0,0 +1,83 @@
package pmtud
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrICMPEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrICMPEchoDataMismatch, sent, received)
}
return nil
}

10
internal/pmtud/df.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build !linux && !windows
package pmtud
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr) (err error) {
return nil
}

View File

@@ -0,0 +1,12 @@
//go:build linux
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
}

29
internal/pmtud/errors.go Normal file
View File

@@ -0,0 +1,29 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrICMPNotPermitted = errors.New("ICMP not permitted")
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrICMPNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}

View File

@@ -0,0 +1,7 @@
package pmtud
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}

159
internal/pmtud/ipv4.go Normal file
View File

@@ -0,0 +1,159 @@
package pmtud
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"syscall"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU = 68
icmpv4Protocol = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
}
return packetConn, nil
}
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
) (mtu int, err error) {
if ip.Is6() {
panic("IP address is not v4")
}
conn, err := listenICMPv4(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop in case we read an echo reply for another ICMP request
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
// Side note: echo reply should be at most the number of bytes
// sent, and can be lower, more precisely 576-ipHeader bytes,
// in case the next hop we are reaching replies with a destination
// unreachable and wants to ensure the response makes it way back
// by keeping a low packet size, see:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrICMPDestinationUnreachable,
ErrICMPCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrICMPDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = int(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
// The code below is really for sanity checks
packetBytes = packetBytes[8:]
header, err := ipv4.ParseHeader(packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
}
packetBytes = packetBytes[header.Len:] // truncated original datagram
const truncated = true
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
if err != nil {
return 0, fmt.Errorf("checking echo reply: %w", err)
}
return mtu, nil
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}

122
internal/pmtud/ipv6.go Normal file
View File

@@ -0,0 +1,122 @@
package pmtud
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
minIPv6MTU = 1280
icmpv6Protocol = 58
)
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
return packetConn, nil
}
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
) (mtu int, err error) {
if ip.Is4() {
panic("IP address is not v6")
}
conn, err := listenICMPv6(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop if we encounter another ICMP packet with an unknown id.
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
packetBytes = packetBytes[ipv6.HeaderLen:]
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = typedBody.MTU
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
// Sanity checks
const truncatedBody = true
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
if err != nil {
return 0, fmt.Errorf("checking invoking message: %w", err)
}
return typedBody.MTU, nil
case *icmp.DstUnreach:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}

58
internal/pmtud/message.go Normal file
View File

@@ -0,0 +1,58 @@
package pmtud
import (
cryptorand "crypto/rand"
"encoding/binary"
"fmt"
"math/rand/v2"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func buildMessageToSend(ipVersion string, mtu int) (id uint16, message *icmp.Message) {
var seed [32]byte
_, _ = cryptorand.Read(seed[:])
randomSource := rand.NewChaCha8(seed)
const uint16Bytes = 2
idBytes := make([]byte, uint16Bytes)
_, _ = randomSource.Read(idBytes)
id = binary.BigEndian.Uint16(idBytes)
var ipHeaderLength int
var icmpType icmp.Type
switch ipVersion {
case "v4":
ipHeaderLength = ipv4.HeaderLen
icmpType = ipv4.ICMPTypeEcho
case "v6":
ipHeaderLength = ipv6.HeaderLen
icmpType = ipv6.ICMPTypeEchoRequest
default:
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
}
const pingHeaderLength = 0 +
1 + // type
1 + // code
2 + // checksum
2 + // identifier
2 // sequence number
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
messageBodyData := make([]byte, pingBodyDataSize)
_, _ = randomSource.Read(messageBodyData)
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
message = &icmp.Message{
Type: icmpType, // echo request
Code: 0, // no code
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
Body: &icmp.Echo{
ID: int(id),
Seq: 0, // only one packet
Data: messageBodyData,
},
}
return id, message
}

View File

@@ -0,0 +1,7 @@
package pmtud
type noopLogger struct{}
func (noopLogger) Debug(_ string) {}
func (noopLogger) Debugf(_ string, _ ...any) {}
func (noopLogger) Warnf(_ string, _ ...any) {}

271
internal/pmtud/pmtud.go Normal file
View File

@@ -0,0 +1,271 @@
package pmtud
import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
)
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
// If the pingTimeout is zero, it defaults to 1 second.
// If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger) (
mtu int, err error,
) {
if physicalLinkMTU == 0 {
const ethernetStandardMTU = 1500
physicalLinkMTU = ethernetStandardMTU
}
if pingTimeout == 0 {
pingTimeout = time.Second
}
if logger == nil {
logger = &noopLogger{}
}
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := minIPv4MTU
if ip.Is6() {
minMTU = minIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
}
type pmtudTestUnit struct {
mtu int
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU int, pingTimeout time.Duration,
logger Logger,
) (maxMTU int, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("testing the following MTUs: %v", mtusToTest)
tests := make([]pmtudTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// Create the MTU slice of length 11 such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func makeMTUsToTest(minMTU, maxMTU int) (mtus []int) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]int, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]int, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, int(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
func collectReplies(conn net.PacketConn, ipVersion string,
tests []pmtudTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}

View File

@@ -0,0 +1,22 @@
//go:build integration
package pmtud
import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_PathMTUDiscover(t *testing.T) {
t.Parallel()
const physicalLinkMTU = 1500
const timeout = time.Second
mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"),
physicalLinkMTU, timeout, nil)
require.NoError(t, err)
t.Log("MTU found:", mtu)
}

View File

@@ -0,0 +1,55 @@
package pmtud
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_makeMTUsToTest(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
minMTU int
maxMTU int
mtus []int
}{
"0_0": {
mtus: []int{0},
},
"0_1": {
maxMTU: 1,
mtus: []int{0, 1},
},
"0_8": {
maxMTU: 8,
mtus: []int{0, 1, 2, 3, 4, 5, 6, 7, 8},
},
"0_12": {
maxMTU: 12,
mtus: []int{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12},
},
"0_80": {
maxMTU: 80,
mtus: []int{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80},
},
"0_100": {
maxMTU: 100,
mtus: []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
},
"1280_1500": {
minMTU: 1280,
maxMTU: 1500,
mtus: []int{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
assert.Equal(t, testCase.mtus, mtus)
})
}
}

View File

@@ -18,7 +18,6 @@ func runCommand(ctx context.Context, cmder Cmder, logger Logger,
} }
portsString := strings.Join(portStrings, ",") portsString := strings.Join(portStrings, ",")
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString) commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
commandString = strings.ReplaceAll(commandString, "{{PORT}}", portStrings[0])
args, err := command.Split(commandString) args, err := command.Split(commandString)
if err != nil { if err != nil {
return fmt.Errorf("parsing command: %w", err) return fmt.Errorf("parsing command: %w", err)

View File

@@ -14,11 +14,7 @@ func (s *Service) writePortForwardedFile(ports []uint16) (err error) {
fileData := []byte(strings.Join(portStrings, "\n")) fileData := []byte(strings.Join(portStrings, "\n"))
filepath := s.settings.Filepath filepath := s.settings.Filepath
if len(ports) == 0 { s.logger.Info("writing port file " + filepath)
s.logger.Info("clearing port file " + filepath)
} else {
s.logger.Info("writing port file " + filepath)
}
const perms = os.FileMode(0o644) const perms = os.FileMode(0o644)
err = os.WriteFile(filepath, fileData, perms) err = os.WriteFile(filepath, fileData, perms)
if err != nil { if err != nil {

View File

@@ -59,6 +59,8 @@ func (s *Service) cleanup() (err error) {
s.ports = nil s.ports = nil
filepath := s.settings.Filepath
s.logger.Info("clearing port file " + filepath)
err = s.writePortForwardedFile(nil) err = s.writePortForwardedFile(nil)
if err != nil { if err != nil {
return fmt.Errorf("clearing port file: %w", err) return fmt.Errorf("clearing port file: %w", err)

View File

@@ -13,7 +13,6 @@ var (
ErrNotEnoughServers = errors.New("not enough servers found") ErrNotEnoughServers = errors.New("not enough servers found")
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
ErrIPFetcherUnsupported = errors.New("IP fetcher not supported") ErrIPFetcherUnsupported = errors.New("IP fetcher not supported")
ErrCredentialsMissing = errors.New("credentials missing")
) )
type Fetcher interface { type Fetcher interface {

View File

@@ -18,12 +18,11 @@ type Provider struct {
func New(storage common.Storage, randSource rand.Source, func New(storage common.Storage, randSource rand.Source,
client *http.Client, updaterWarner common.Warner, client *http.Client, updaterWarner common.Warner,
username, password string,
) *Provider { ) *Provider {
return &Provider{ return &Provider{
storage: storage, storage: storage,
randSource: randSource, randSource: randSource,
Fetcher: updater.New(client, updaterWarner, username, password), Fetcher: updater.New(client, updaterWarner),
} }
} }

View File

@@ -1,567 +1,15 @@
package updater package updater
import ( import (
"bytes"
"context" "context"
crand "crypto/rand"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"math/rand/v2"
"net/http" "net/http"
"net/netip" "net/netip"
"slices"
"strings"
srp "github.com/ProtonMail/go-srp"
) )
// apiClient is a minimal Proton v4 API client which can handle all the var ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
// oddities of Proton's authentication flow they want to keep hidden
// from the public.
type apiClient struct {
apiURLBase string
httpClient *http.Client
appVersion string
userAgent string
generator *rand.ChaCha8
}
// newAPIClient returns an [apiClient] with sane defaults matching Proton's
// insane expectations.
func newAPIClient(ctx context.Context, httpClient *http.Client) (client *apiClient, err error) {
var seed [32]byte
_, _ = crand.Read(seed[:])
generator := rand.NewChaCha8(seed)
// Pick a random user agent from this list. Because I'm not going to tell
// Proton shit on where all these funny requests are coming from, given their
// unhelpfulness in figuring out their authentication flow.
userAgents := [...]string{
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:143.0) Gecko/20100101 Firefox/143.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:143.0) Gecko/20100101 Firefox/143.0",
"Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Gecko/20100101 Firefox/143.0",
}
userAgent := userAgents[generator.Uint64()%uint64(len(userAgents))]
appVersion, err := getMostRecentStableTag(ctx, httpClient)
if err != nil {
return nil, fmt.Errorf("getting most recent version for proton app: %w", err)
}
return &apiClient{
apiURLBase: "https://account.proton.me/api",
httpClient: httpClient,
appVersion: appVersion,
userAgent: userAgent,
generator: generator,
}, nil
}
var ErrCodeNotSuccess = errors.New("response code is not success")
// setHeaders sets the minimal necessary headers for Proton API requests
// to succeed without being blocked by their "security" measures.
// See for example [getMostRecentStableTag] on how the app version must
// be set to a recent version or they block your request. "SeCuRiTy"...
func (c *apiClient) setHeaders(request *http.Request, cookie cookie) {
request.Header.Set("Cookie", cookie.String())
request.Header.Set("User-Agent", c.userAgent)
request.Header.Set("x-pm-appversion", c.appVersion)
request.Header.Set("x-pm-locale", "en_US")
request.Header.Set("x-pm-uid", cookie.uid)
}
// authenticate performs the full Proton authentication flow
// to obtain an authenticated cookie (uid, token and session ID).
func (c *apiClient) authenticate(ctx context.Context, username, password string,
) (authCookie cookie, err error) {
sessionID, err := c.getSessionID(ctx)
if err != nil {
return cookie{}, fmt.Errorf("getting session ID: %w", err)
}
tokenType, accessToken, refreshToken, uid, err := c.getUnauthSession(ctx, sessionID)
if err != nil {
return cookie{}, fmt.Errorf("getting unauthenticated session data: %w", err)
}
cookieToken, err := c.cookieToken(ctx, sessionID, tokenType, accessToken, refreshToken, uid)
if err != nil {
return cookie{}, fmt.Errorf("getting cookie token: %w", err)
}
unauthCookie := cookie{
uid: uid,
token: cookieToken,
sessionID: sessionID,
}
modulusPGPClearSigned, serverEphemeralBase64, saltBase64,
srpSessionHex, version, err := c.authInfo(ctx, username, unauthCookie)
if err != nil {
return cookie{}, fmt.Errorf("getting auth information: %w", err)
}
// Prepare SRP proof generator using Proton's official SRP parameters and hashing.
srpAuth, err := srp.NewAuth(version, username, []byte(password),
saltBase64, modulusPGPClearSigned, serverEphemeralBase64)
if err != nil {
return cookie{}, fmt.Errorf("initializing SRP auth: %w", err)
}
// Generate SRP proofs (A, M1) with the usual 2048-bit modulus.
const modulusBits = 2048
proofs, err := srpAuth.GenerateProofs(modulusBits)
if err != nil {
return cookie{}, fmt.Errorf("generating SRP proofs: %w", err)
}
authCookie, err = c.auth(ctx, unauthCookie, username, srpSessionHex, proofs)
if err != nil {
return cookie{}, fmt.Errorf("authentifying: %w", err)
}
return authCookie, nil
}
var ErrSessionIDNotFound = errors.New("session ID not found in cookies")
func (c *apiClient) getSessionID(ctx context.Context) (sessionID string, err error) {
const url = "https://account.proton.me/vpn"
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
response, err := c.httpClient.Do(request)
if err != nil {
return "", err
}
err = response.Body.Close()
if err != nil {
return "", fmt.Errorf("closing response body: %w", err)
}
for _, cookie := range response.Cookies() {
if cookie.Name == "Session-Id" {
return cookie.Value, nil
}
}
return "", fmt.Errorf("%w", ErrSessionIDNotFound)
}
var ErrDataFieldMissing = errors.New("data field missing in response")
func (c *apiClient) getUnauthSession(ctx context.Context, sessionID string) (
tokenType, accessToken, refreshToken, uid string, err error,
) {
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/auth/v4/sessions", nil)
if err != nil {
return "", "", "", "", fmt.Errorf("creating request: %w", err)
}
unauthCookie := cookie{
sessionID: sessionID,
}
c.setHeaders(request, unauthCookie)
response, err := c.httpClient.Do(request)
if err != nil {
return "", "", "", "", err
}
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
if err != nil {
return "", "", "", "", fmt.Errorf("reading response body: %w", err)
} else if response.StatusCode != http.StatusOK {
return "", "", "", "", buildError(response.StatusCode, responseBody)
}
var data struct {
Code uint `json:"Code"` // 1000 on success
AccessToken string `json:"AccessToken"` // 32-chars lowercase and digits
RefreshToken string `json:"RefreshToken"` // 32-chars lowercase and digits
TokenType string `json:"TokenType"` // "Bearer"
Scopes []string `json:"Scopes"` // should be [] for our usage
UID string `json:"UID"` // 32-chars lowercase and digits
LocalID uint `json:"LocalID"` // 0 in my case
}
err = json.Unmarshal(responseBody, &data)
if err != nil {
return "", "", "", "", fmt.Errorf("decoding response body: %w", err)
}
const successCode = 1000
switch {
case data.Code != successCode:
return "", "", "", "", fmt.Errorf("%w: expected %d got %d",
ErrCodeNotSuccess, successCode, data.Code)
case data.AccessToken == "":
return "", "", "", "", fmt.Errorf("%w: access token is empty", ErrDataFieldMissing)
case data.RefreshToken == "":
return "", "", "", "", fmt.Errorf("%w: refresh token is empty", ErrDataFieldMissing)
case data.TokenType == "":
return "", "", "", "", fmt.Errorf("%w: token type is empty", ErrDataFieldMissing)
case data.UID == "":
return "", "", "", "", fmt.Errorf("%w: UID is empty", ErrDataFieldMissing)
}
// Ignore Scopes and LocalID fields, we don't use them.
return data.TokenType, data.AccessToken, data.RefreshToken, data.UID, nil
}
var ErrUIDMismatch = errors.New("UID in response does not match request UID")
func (c *apiClient) cookieToken(ctx context.Context, sessionID, tokenType, accessToken,
refreshToken, uid string,
) (cookieToken string, err error) {
type requestBodySchema struct {
GrantType string `json:"GrantType"` // "refresh_token"
Persistent uint `json:"Persistent"` // 0
RedirectURI string `json:"RedirectURI"` // "https://protonmail.com"
RefreshToken string `json:"RefreshToken"` // 32-chars lowercase and digits
ResponseType string `json:"ResponseType"` // "token"
State string `json:"State"` // 24-chars letters and digits
UID string `json:"UID"` // 32-chars lowercase and digits
}
requestBody := requestBodySchema{
GrantType: "refresh_token",
Persistent: 0,
RedirectURI: "https://protonmail.com",
RefreshToken: refreshToken,
ResponseType: "token",
State: generateLettersDigits(c.generator, 24), //nolint:mnd
UID: uid,
}
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
if err := encoder.Encode(requestBody); err != nil {
return "", fmt.Errorf("encoding request body: %w", err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/cookies", buffer)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
unauthCookie := cookie{
uid: uid,
sessionID: sessionID,
}
c.setHeaders(request, unauthCookie)
request.Header.Set("Authorization", tokenType+" "+accessToken)
response, err := c.httpClient.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("reading response body: %w", err)
} else if response.StatusCode != http.StatusOK {
return "", buildError(response.StatusCode, responseBody)
}
var cookies struct {
Code uint `json:"Code"` // 1000 on success
UID string `json:"UID"` // should match request UID
LocalID uint `json:"LocalID"` // 0
RefreshCounter uint `json:"RefreshCounter"` // 1
}
err = json.Unmarshal(responseBody, &cookies)
if err != nil {
return "", fmt.Errorf("decoding response body: %w", err)
}
const successCode = 1000
switch {
case cookies.Code != successCode:
return "", fmt.Errorf("%w: expected %d got %d",
ErrCodeNotSuccess, successCode, cookies.Code)
case cookies.UID != requestBody.UID:
return "", fmt.Errorf("%w: expected %s got %s",
ErrUIDMismatch, requestBody.UID, cookies.UID)
}
// Ignore LocalID and RefreshCounter fields, we don't use them.
for _, cookie := range response.Cookies() {
if cookie.Name == "AUTH-"+uid {
return cookie.Value, nil
}
}
return "", fmt.Errorf("%w", ErrAuthCookieNotFound)
}
var (
ErrUsernameDoesNotExist = errors.New("username does not exist")
ErrUsernameMismatch = errors.New("username in response does not match request username")
)
// authInfo fetches SRP parameters for the account.
func (c *apiClient) authInfo(ctx context.Context, username string, unauthCookie cookie) (
modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string,
version int, err error,
) {
type requestBodySchema struct {
Intent string `json:"Intent"` // "Proton"
Username string `json:"Username"` // username without @domain.com
}
requestBody := requestBodySchema{
Intent: "Proton",
Username: username,
}
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
if err := encoder.Encode(requestBody); err != nil {
return "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/info", buffer)
if err != nil {
return "", "", "", "", 0, fmt.Errorf("creating request: %w", err)
}
c.setHeaders(request, unauthCookie)
response, err := c.httpClient.Do(request)
if err != nil {
return "", "", "", "", 0, err
}
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
if err != nil {
return "", "", "", "", 0, fmt.Errorf("reading response body: %w", err)
} else if response.StatusCode != http.StatusOK {
return "", "", "", "", 0, buildError(response.StatusCode, responseBody)
}
var info struct {
Code uint `json:"Code"` // 1000 on success
Modulus string `json:"Modulus"` // PGP clearsigned modulus string
ServerEphemeral string `json:"ServerEphemeral"` // base64
Version *uint `json:"Version,omitempty"` // 4 as of 2025-10-26
Salt string `json:"Salt"` // base64
SRPSession string `json:"SRPSession"` // hexadecimal
Username string `json:"Username"` // user without @domain.com. Mine has its first letter capitalized.
}
err = json.Unmarshal(responseBody, &info)
if err != nil {
return "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err)
}
const successCode = 1000
switch {
case info.Code != successCode:
return "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d",
ErrCodeNotSuccess, successCode, info.Code)
case info.Modulus == "":
return "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing)
case info.ServerEphemeral == "":
return "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing)
case info.Salt == "":
return "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist)
case info.SRPSession == "":
return "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing)
case !strings.EqualFold(info.Username, username):
return "", "", "", "", 0, fmt.Errorf("%w: expected %s got %s",
ErrUsernameMismatch, username, info.Username)
case info.Version == nil:
return "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing)
}
version = int(*info.Version) //nolint:gosec
return info.Modulus, info.ServerEphemeral, info.Salt,
info.SRPSession, version, nil
}
type cookie struct {
uid string
token string
sessionID string
}
func (c *cookie) String() string {
s := ""
if c.token != "" {
s += fmt.Sprintf("AUTH-%s=%s; ", c.uid, c.token)
}
if c.sessionID != "" {
s += fmt.Sprintf("Session-Id=%s; ", c.sessionID)
}
if c.token != "" {
s += "Tag=default; iaas=W10; Domain=proton.me; Feature=VPNDashboard:A"
}
return s
}
var (
// ErrServerProofNotValid indicates the M2 from the server didn't match the expected proof.
ErrServerProofNotValid = errors.New("server proof from server is not valid")
ErrVPNScopeNotFound = errors.New("VPN scope not found in scopes")
ErrTwoFANotSupported = errors.New("two factor authentication not supported in this client")
ErrAuthCookieNotFound = errors.New("auth cookie not found")
)
// auth performs the SRP proof submission (and optionally TOTP) to obtain tokens.
func (c *apiClient) auth(ctx context.Context, unauthCookie cookie,
username, srpSession string, proofs *srp.Proofs,
) (authCookie cookie, err error) {
clientEphemeral := base64.StdEncoding.EncodeToString(proofs.ClientEphemeral)
clientProof := base64.StdEncoding.EncodeToString(proofs.ClientProof)
type requestBodySchema struct {
ClientEphemeral string `json:"ClientEphemeral"` // base64(A)
ClientProof string `json:"ClientProof"` // base64(M1)
Payload map[string]string `json:"Payload,omitempty"` // not sure
SRPSession string `json:"SRPSession"` // hexadecimal
Username string `json:"Username"` // user@protonmail.com
}
requestBody := requestBodySchema{
ClientEphemeral: clientEphemeral,
ClientProof: clientProof,
SRPSession: srpSession,
Username: username,
}
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
if err := encoder.Encode(requestBody); err != nil {
return cookie{}, fmt.Errorf("encoding request body: %w", err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth", buffer)
if err != nil {
return cookie{}, fmt.Errorf("creating request: %w", err)
}
c.setHeaders(request, unauthCookie)
response, err := c.httpClient.Do(request)
if err != nil {
return cookie{}, err
}
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
if err != nil {
return cookie{}, fmt.Errorf("reading response body: %w", err)
} else if response.StatusCode != http.StatusOK {
return cookie{}, buildError(response.StatusCode, responseBody)
}
type twoFAStatus uint
//nolint:unused
const (
twoFADisabled twoFAStatus = iota
twoFAHasTOTP
twoFAHasFIDO2
twoFAHasFIDO2AndTOTP
)
type twoFAInfo struct {
Enabled twoFAStatus `json:"Enabled"`
FIDO2 struct {
AuthenticationOptions any `json:"AuthenticationOptions"`
RegisteredKeys []any `json:"RegisteredKeys"`
} `json:"FIDO2"`
TOTP uint `json:"TOTP"`
}
var auth struct {
Code uint `json:"Code"` // 1000 on success
LocalID uint `json:"LocalID"` // 7 in my case
Scopes []string `json:"Scopes"` // this should contain "vpn". Same as `Scope` field value.
UID string `json:"UID"` // same as `Uid` field value
UserID string `json:"UserID"` // base64
EventID string `json:"EventID"` // base64
PasswordMode uint `json:"PasswordMode"` // 1 in my case
ServerProof string `json:"ServerProof"` // base64(M2)
TwoFactor uint `json:"TwoFactor"` // 0 if 2FA not required
TwoFA twoFAInfo `json:"2FA"`
TemporaryPassword uint `json:"TemporaryPassword"` // 0 in my case
}
err = json.Unmarshal(responseBody, &auth)
if err != nil {
return cookie{}, fmt.Errorf("decoding response body: %w", err)
}
m2, err := base64.StdEncoding.DecodeString(auth.ServerProof)
if err != nil {
return cookie{}, fmt.Errorf("decoding server proof: %w", err)
}
if !bytes.Equal(m2, proofs.ExpectedServerProof) {
return cookie{}, fmt.Errorf("%w: expected %x got %x",
ErrServerProofNotValid, proofs.ExpectedServerProof, m2)
}
const successCode = 1000
switch {
case auth.Code != successCode:
return cookie{}, fmt.Errorf("%w: expected %d got %d",
ErrCodeNotSuccess, successCode, auth.Code)
case auth.UID != unauthCookie.uid:
return cookie{}, fmt.Errorf("%w: expected %s got %s",
ErrUIDMismatch, unauthCookie.uid, auth.UID)
case auth.TwoFactor != 0:
return cookie{}, fmt.Errorf("%w", ErrTwoFANotSupported)
case !slices.Contains(auth.Scopes, "vpn"):
return cookie{}, fmt.Errorf("%w: in %v", ErrVPNScopeNotFound, auth.Scopes)
}
for _, setCookieHeader := range response.Header.Values("Set-Cookie") {
parts := strings.Split(setCookieHeader, ";")
for _, part := range parts {
if strings.HasPrefix(part, "AUTH-"+unauthCookie.uid+"=") {
authCookie = unauthCookie
authCookie.token = strings.TrimPrefix(part, "AUTH-"+unauthCookie.uid+"=")
return authCookie, nil
}
}
}
return cookie{}, fmt.Errorf("%w: in HTTP headers %s",
ErrAuthCookieNotFound, httpHeadersToString(response.Header))
}
// generateLettersDigits mimicing Proton's own random string generator:
// https://github.com/ProtonMail/WebClients/blob/e4d7e4ab9babe15b79a131960185f9f8275512cd/packages/utils/generateLettersDigits.ts
func generateLettersDigits(rng *rand.ChaCha8, length uint) string {
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
return generateFromCharset(rng, length, charset)
}
func generateFromCharset(rng *rand.ChaCha8, length uint, charset string) string {
result := make([]byte, length)
randomBytes := make([]byte, length)
_, _ = rng.Read(randomBytes)
for i := range length {
result[i] = charset[int(randomBytes[i])%len(charset)]
}
return string(result)
}
func httpHeadersToString(headers http.Header) string {
var builder strings.Builder
first := true
for key, values := range headers {
for _, value := range values {
if !first {
builder.WriteString(", ")
}
builder.WriteString(fmt.Sprintf("%s: %s", key, value))
first = false
}
}
return builder.String()
}
type apiData struct { type apiData struct {
LogicalServers []logicalServer `json:"LogicalServers"` LogicalServers []logicalServer `json:"LogicalServers"`
@@ -585,25 +33,25 @@ type physicalServer struct {
X25519PublicKey string `json:"X25519PublicKey"` X25519PublicKey string `json:"X25519PublicKey"`
} }
func (c *apiClient) fetchServers(ctx context.Context, cookie cookie) ( func fetchAPI(ctx context.Context, client *http.Client) (
data apiData, err error, data apiData, err error,
) { ) {
const url = "https://account.proton.me/api/vpn/logicals" const url = "https://api.protonmail.ch/vpn/logicals"
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return data, err return data, err
} }
c.setHeaders(request, cookie)
response, err := c.httpClient.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
return data, err return data, err
} }
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
b, _ := io.ReadAll(response.Body) return data, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
return data, buildError(response.StatusCode, b) response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
@@ -611,31 +59,9 @@ func (c *apiClient) fetchServers(ctx context.Context, cookie cookie) (
return data, fmt.Errorf("decoding response body: %w", err) return data, fmt.Errorf("decoding response body: %w", err)
} }
if err := response.Body.Close(); err != nil {
return data, err
}
return data, nil return data, nil
} }
var ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
func buildError(httpCode int, body []byte) error {
prettyCode := http.StatusText(httpCode)
var protonError struct {
Code *int `json:"Code,omitempty"`
Error *string `json:"Error,omitempty"`
Details map[string]string `json:"Details"`
}
decoder := json.NewDecoder(bytes.NewReader(body))
decoder.DisallowUnknownFields()
err := decoder.Decode(&protonError)
if err != nil || protonError.Error == nil || protonError.Code == nil {
return fmt.Errorf("%w: %s: %s",
ErrHTTPStatusCodeNotOK, prettyCode, body)
}
details := make([]string, 0, len(protonError.Details))
for key, value := range protonError.Details {
details = append(details, fmt.Sprintf("%s: %s", key, value))
}
return fmt.Errorf("%w: %s: %s (code %d with details: %s)",
ErrHTTPStatusCodeNotOK, prettyCode, *protonError.Error, *protonError.Code, strings.Join(details, ", "))
}

View File

@@ -13,26 +13,9 @@ import (
func (u *Updater) FetchServers(ctx context.Context, minServers int) ( func (u *Updater) FetchServers(ctx context.Context, minServers int) (
servers []models.Server, err error, servers []models.Server, err error,
) { ) {
switch { data, err := fetchAPI(ctx, u.client)
case u.username == "":
return nil, fmt.Errorf("%w: username is empty", common.ErrCredentialsMissing)
case u.password == "":
return nil, fmt.Errorf("%w: password is empty", common.ErrCredentialsMissing)
}
apiClient, err := newAPIClient(ctx, u.client)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating API client: %w", err) return nil, err
}
cookie, err := apiClient.authenticate(ctx, u.username, u.password)
if err != nil {
return nil, fmt.Errorf("authentifying with Proton: %w", err)
}
data, err := apiClient.fetchServers(ctx, cookie)
if err != nil {
return nil, fmt.Errorf("fetching logical servers: %w", err)
} }
countryCodes := constants.CountryCodes() countryCodes := constants.CountryCodes()

View File

@@ -7,17 +7,13 @@ import (
) )
type Updater struct { type Updater struct {
client *http.Client client *http.Client
username string warner common.Warner
password string
warner common.Warner
} }
func New(client *http.Client, warner common.Warner, username, password string) *Updater { func New(client *http.Client, warner common.Warner) *Updater {
return &Updater{ return &Updater{
client: client, client: client,
username: username, warner: warner,
password: password,
warner: warner,
} }
} }

View File

@@ -1,64 +0,0 @@
package updater
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
)
// getMostRecentStableTag finds the most recent proton-account stable tag version,
// in order to use it in the x-pm-appversion http request header. Because if we do
// fall behind on versioning, Proton doesn't like it because they like to create
// complications where there is no need for it. Hence this function.
func getMostRecentStableTag(ctx context.Context, client *http.Client) (version string, err error) {
page := 1
regexVersion := regexp.MustCompile(`^proton-account@(\d+\.\d+\.\d+\.\d+)$`)
for ctx.Err() == nil {
url := "https://api.github.com/repos/ProtonMail/WebClients/tags?per_page=30&page=" + fmt.Sprint(page)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
request.Header.Set("Accept", "application/vnd.github.v3+json")
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
data, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("reading response body: %w", err)
}
if response.StatusCode != http.StatusOK {
return "", fmt.Errorf("%w: %s: %s", ErrHTTPStatusCodeNotOK, response.Status, data)
}
var tags []struct {
Name string `json:"name"`
}
err = json.Unmarshal(data, &tags)
if err != nil {
return "", fmt.Errorf("decoding JSON response: %w", err)
}
for _, tag := range tags {
if !regexVersion.MatchString(tag.Name) {
continue
}
version := "web-account@" + strings.TrimPrefix(tag.Name, "proton-account@")
return version, nil
}
page++
}
return "", fmt.Errorf("%w (queried %d pages)", context.Canceled, page)
}

View File

@@ -54,7 +54,7 @@ type Extractor interface {
func NewProviders(storage Storage, timeNow func() time.Time, func NewProviders(storage Storage, timeNow func() time.Time,
updaterWarner common.Warner, client *http.Client, unzipper common.Unzipper, updaterWarner common.Warner, client *http.Client, unzipper common.Unzipper,
parallelResolver common.ParallelResolver, ipFetcher common.IPFetcher, parallelResolver common.ParallelResolver, ipFetcher common.IPFetcher,
extractor custom.Extractor, credentials settings.Updater, extractor custom.Extractor,
) *Providers { ) *Providers {
randSource := rand.NewSource(timeNow().UnixNano()) randSource := rand.NewSource(timeNow().UnixNano())
@@ -75,7 +75,7 @@ func NewProviders(storage Storage, timeNow func() time.Time,
providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver), providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client), providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client),
providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver), providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonUsername, *credentials.ProtonPassword), providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner),
providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver), providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver), providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver),
providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver), providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),

View File

@@ -25,14 +25,13 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
handler := &handler{} handler := &handler{}
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger) vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, logger) openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger)
dns := newDNSHandler(ctx, dnsLooper, logger) dns := newDNSHandler(ctx, dnsLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger)
portForward := newPortForwardHandler(ctx, pfGetter, logger)
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, dnsLooper, updaterLooper) handler.v0 = newHandlerV0(ctx, logger, vpnLooper, dnsLooper, updaterLooper)
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip, portForward) handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
authMiddleware, err := auth.New(authSettings, logger) authMiddleware, err := auth.New(authSettings, logger)
if err != nil { if err != nil {

View File

@@ -52,7 +52,7 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.logger.Warn(err.Error()) h.logger.Warn(err.Error())
} }
case "/openvpn/portforwarded": case "/openvpn/portforwarded":
http.Redirect(w, r, "/v1/portforward", http.StatusPermanentRedirect) http.Redirect(w, r, "/v1/openvpn/portforwarded", http.StatusPermanentRedirect)
case "/openvpn/settings": case "/openvpn/settings":
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect) http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
case "/updater/restart": case "/updater/restart":

View File

@@ -10,29 +10,27 @@ import (
) )
func newHandlerV1(w warner, buildInfo models.BuildInformation, func newHandlerV1(w warner, buildInfo models.BuildInformation,
vpn, openvpn, dns, updater, publicip, portForward http.Handler, vpn, openvpn, dns, updater, publicip http.Handler,
) http.Handler { ) http.Handler {
return &handlerV1{ return &handlerV1{
warner: w, warner: w,
buildInfo: buildInfo, buildInfo: buildInfo,
vpn: vpn, vpn: vpn,
openvpn: openvpn, openvpn: openvpn,
dns: dns, dns: dns,
updater: updater, updater: updater,
publicip: publicip, publicip: publicip,
portForward: portForward,
} }
} }
type handlerV1 struct { type handlerV1 struct {
warner warner warner warner
buildInfo models.BuildInformation buildInfo models.BuildInformation
vpn http.Handler vpn http.Handler
openvpn http.Handler openvpn http.Handler
dns http.Handler dns http.Handler
updater http.Handler updater http.Handler
publicip http.Handler publicip http.Handler
portForward http.Handler
} }
func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -49,8 +47,6 @@ func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.updater.ServeHTTP(w, r) h.updater.ServeHTTP(w, r)
case strings.HasPrefix(r.RequestURI, "/publicip"): case strings.HasPrefix(r.RequestURI, "/publicip"):
h.publicip.ServeHTTP(w, r) h.publicip.ServeHTTP(w, r)
case strings.HasPrefix(r.RequestURI, "/portforward"):
h.portForward.ServeHTTP(w, r)
default: default:
errString := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI) errString := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI)
http.Error(w, errString, http.StatusBadRequest) http.Error(w, errString, http.StatusBadRequest)

View File

@@ -20,7 +20,6 @@ func New(settings Settings, debugLogger DebugLogger) (
routeToRoles: routeToRoles, routeToRoles: routeToRoles,
unprotectedRoutes: map[string]struct{}{ unprotectedRoutes: map[string]struct{}{
http.MethodGet + " /openvpn/actions/restart": {}, http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /openvpn/portforwarded": {},
http.MethodGet + " /unbound/actions/restart": {}, http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {}, http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {}, http.MethodGet + " /v1/version": {},
@@ -37,7 +36,6 @@ func New(settings Settings, debugLogger DebugLogger) (
http.MethodGet + " /v1/updater/status": {}, http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {}, http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {}, http.MethodGet + " /v1/publicip/ip": {},
http.MethodGet + " /v1/portforward": {},
}, },
logger: debugLogger, logger: debugLogger,
} }

View File

@@ -1,16 +1,12 @@
package auth package auth
import ( import (
"bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"slices"
"github.com/qdm12/gosettings" "github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/validate" "github.com/qdm12/gosettings/validate"
"github.com/qdm12/gotree"
) )
type Settings struct { type Settings struct {
@@ -19,50 +15,6 @@ type Settings struct {
Roles []Role Roles []Role
} }
// SetDefaultRole sets a default role to apply to all routes without a
// previously user-defined role assigned to. Note the role argument
// routes are ignored. This should be called BEFORE calling [Settings.SetDefaults].
func (s *Settings) SetDefaultRole(jsonRole string) error {
var role Role
decoder := json.NewDecoder(bytes.NewBufferString(jsonRole))
decoder.DisallowUnknownFields()
err := decoder.Decode(&role)
if err != nil {
return fmt.Errorf("decoding default role: %w", err)
}
if role.Auth == "" {
return nil // no default role to set
}
err = role.Validate()
if err != nil {
return fmt.Errorf("validating default role: %w", err)
}
authenticatedRoutes := make(map[string]struct{}, len(validRoutes))
for _, role := range s.Roles {
for _, route := range role.Routes {
authenticatedRoutes[route] = struct{}{}
}
}
if len(authenticatedRoutes) == len(validRoutes) {
return nil
}
unauthenticatedRoutes := make([]string, 0, len(validRoutes))
for route := range validRoutes {
_, authenticated := authenticatedRoutes[route]
if !authenticated {
unauthenticatedRoutes = append(unauthenticatedRoutes, route)
}
}
slices.Sort(unauthenticatedRoutes)
role.Routes = unauthenticatedRoutes
s.Roles = append(s.Roles, role)
return nil
}
func (s *Settings) SetDefaults() { func (s *Settings) SetDefaults() {
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
Name: "public", Name: "public",
@@ -70,7 +22,6 @@ func (s *Settings) SetDefaults() {
Routes: []string{ Routes: []string{
http.MethodGet + " /openvpn/actions/restart", http.MethodGet + " /openvpn/actions/restart",
http.MethodGet + " /unbound/actions/restart", http.MethodGet + " /unbound/actions/restart",
http.MethodGet + " /openvpn/portforwarded",
http.MethodGet + " /updater/restart", http.MethodGet + " /updater/restart",
http.MethodGet + " /v1/version", http.MethodGet + " /v1/version",
http.MethodGet + " /v1/vpn/status", http.MethodGet + " /v1/vpn/status",
@@ -83,14 +34,13 @@ func (s *Settings) SetDefaults() {
http.MethodGet + " /v1/updater/status", http.MethodGet + " /v1/updater/status",
http.MethodPut + " /v1/updater/status", http.MethodPut + " /v1/updater/status",
http.MethodGet + " /v1/publicip/ip", http.MethodGet + " /v1/publicip/ip",
http.MethodGet + " /v1/portforward",
}, },
}}) }})
} }
func (s Settings) Validate() (err error) { func (s Settings) Validate() (err error) {
for i, role := range s.Roles { for i, role := range s.Roles {
err = role.Validate() err = role.validate()
if err != nil { if err != nil {
return fmt.Errorf("role %s (%d of %d): %w", return fmt.Errorf("role %s (%d of %d): %w",
role.Name, i+1, len(s.Roles), err) role.Name, i+1, len(s.Roles), err)
@@ -111,18 +61,18 @@ const (
type Role struct { type Role struct {
// Name is the role name and is only used for documentation // Name is the role name and is only used for documentation
// and in the authentication middleware debug logs. // and in the authentication middleware debug logs.
Name string `json:"name"` Name string
// Auth is the authentication method to use, which can be 'none', 'basic' or 'apikey'. // Auth is the authentication method to use, which can be 'none' or 'apikey'.
Auth string `json:"auth"` Auth string
// APIKey is the API key to use when using the 'apikey' authentication. // APIKey is the API key to use when using the 'apikey' authentication.
APIKey string `json:"apikey"` APIKey string
// Username for HTTP Basic authentication method. // Username for HTTP Basic authentication method.
Username string `json:"username"` Username string
// Password for HTTP Basic authentication method. // Password for HTTP Basic authentication method.
Password string `json:"password"` Password string
// Routes is a list of routes that the role can access in the format // Routes is a list of routes that the role can access in the format
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status" // "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
Routes []string `json:"-"` Routes []string
} }
var ( var (
@@ -133,7 +83,7 @@ var (
ErrRouteNotSupported = errors.New("route not supported by the control server") ErrRouteNotSupported = errors.New("route not supported by the control server")
) )
func (r Role) Validate() (err error) { func (r Role) validate() (err error) {
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic) err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth) return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
@@ -162,8 +112,6 @@ func (r Role) Validate() (err error) {
// WARNING: do not mutate programmatically. // WARNING: do not mutate programmatically.
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
http.MethodGet + " /openvpn/actions/restart": {}, http.MethodGet + " /openvpn/actions/restart": {},
http.MethodGet + " /openvpn/portforwarded": {},
http.MethodGet + " /openvpn/settings": {},
http.MethodGet + " /unbound/actions/restart": {}, http.MethodGet + " /unbound/actions/restart": {},
http.MethodGet + " /updater/restart": {}, http.MethodGet + " /updater/restart": {},
http.MethodGet + " /v1/version": {}, http.MethodGet + " /v1/version": {},
@@ -180,22 +128,4 @@ var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
http.MethodGet + " /v1/updater/status": {}, http.MethodGet + " /v1/updater/status": {},
http.MethodPut + " /v1/updater/status": {}, http.MethodPut + " /v1/updater/status": {},
http.MethodGet + " /v1/publicip/ip": {}, http.MethodGet + " /v1/publicip/ip": {},
http.MethodGet + " /v1/portforward": {},
}
func (r Role) ToLinesNode() (node *gotree.Node) {
node = gotree.New("Role " + r.Name)
node.Appendf("Authentication method: %s", r.Auth)
switch r.Auth {
case AuthNone:
case AuthBasic:
node.Appendf("Username: %s", r.Username)
node.Appendf("Password: %s", gosettings.ObfuscateKey(r.Password))
case AuthAPIKey:
node.Appendf("API key: %s", gosettings.ObfuscateKey(r.APIKey))
default:
panic("missing code for authentication method: " + r.Auth)
}
node.Appendf("Number of routes covered: %d", len(r.Routes))
return node
} }

View File

@@ -38,7 +38,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.childHandler.ServeHTTP(statefulWriter, r) m.childHandler.ServeHTTP(statefulWriter, r)
duration := m.timeNow().Sub(tStart) duration := m.timeNow().Sub(tStart)
m.logger.Info(strconv.Itoa(statefulWriter.statusCode) + " " + m.logger.Info(strconv.Itoa(statefulWriter.statusCode) + " " +
r.Method + " " + r.URL.String() + r.Method + " " + r.RequestURI +
" wrote " + strconv.Itoa(statefulWriter.length) + "B to " + " wrote " + strconv.Itoa(statefulWriter.length) + "B to " +
r.RemoteAddr + " in " + duration.String()) r.RemoteAddr + " in " + duration.String())
} }

View File

@@ -10,10 +10,13 @@ import (
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
) )
func newOpenvpnHandler(ctx context.Context, looper VPNLooper, w warner) http.Handler { func newOpenvpnHandler(ctx context.Context, looper VPNLooper,
pfGetter PortForwardedGetter, w warner,
) http.Handler {
return &openvpnHandler{ return &openvpnHandler{
ctx: ctx, ctx: ctx,
looper: looper, looper: looper,
pf: pfGetter,
warner: w, warner: w,
} }
} }
@@ -21,6 +24,7 @@ func newOpenvpnHandler(ctx context.Context, looper VPNLooper, w warner) http.Han
type openvpnHandler struct { type openvpnHandler struct {
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
looper VPNLooper looper VPNLooper
pf PortForwardedGetter
warner warner warner warner
} }
@@ -43,10 +47,10 @@ func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
default: default:
errMethodNotSupported(w, r.Method) errMethodNotSupported(w, r.Method)
} }
case "/portforwarded": // TODO v4 remove case "/portforwarded":
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
http.Redirect(w, r, "/v1/portforward", http.StatusMovedPermanently) h.getPortForwarded(w)
default: default:
errMethodNotSupported(w, r.Method) errMethodNotSupported(w, r.Method)
} }
@@ -118,3 +122,23 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
return return
} }
} }
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
ports := h.pf.GetPortsForwarded()
encoder := json.NewEncoder(w)
var data any
switch len(ports) {
case 0:
data = portWrapper{Port: 0} // TODO v4 change to portsWrapper
case 1:
data = portWrapper{Port: ports[0]} // TODO v4 change to portsWrapper
default:
data = portsWrapper{Ports: ports}
}
err := encoder.Encode(data)
if err != nil {
h.warner.Warn(err.Error())
w.WriteHeader(http.StatusInternalServerError)
}
}

View File

@@ -1,52 +0,0 @@
package server
import (
"context"
"encoding/json"
"net/http"
)
func newPortForwardHandler(ctx context.Context,
portForward PortForwardedGetter, warner warner,
) http.Handler {
return &portForwardHandler{
ctx: ctx,
portForward: portForward,
warner: warner,
}
}
type portForwardHandler struct {
ctx context.Context //nolint:containedctx
portForward PortForwardedGetter
warner warner
}
func (h *portForwardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
h.getPortForwarded(w)
default:
errMethodNotSupported(w, r.Method)
}
}
func (h *portForwardHandler) getPortForwarded(w http.ResponseWriter) {
ports := h.portForward.GetPortsForwarded()
encoder := json.NewEncoder(w)
var data any
switch len(ports) {
case 0:
data = portWrapper{Port: 0} // TODO v4 change to portsWrapper
case 1:
data = portWrapper{Port: ports[0]} // TODO v4 change to portsWrapper
default:
data = portsWrapper{Ports: ports}
}
err := encoder.Encode(data)
if err != nil {
h.warner.Warn(err.Error())
w.WriteHeader(http.StatusInternalServerError)
}
}

View File

@@ -6,25 +6,33 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/httpserver" "github.com/qdm12/gluetun/internal/httpserver"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/server/middlewares/auth" "github.com/qdm12/gluetun/internal/server/middlewares/auth"
) )
func New(ctx context.Context, settings settings.ControlServer, logger Logger, func New(ctx context.Context, address string, logEnabled bool, logger Logger,
buildInfo models.BuildInformation, openvpnLooper VPNLooper, authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, dnsLooper DNSLoop, pfGetter PortForwardedGetter, dnsLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) ( ipv6Supported bool) (
server *httpserver.Server, err error, server *httpserver.Server, err error,
) { ) {
authSettings, err := setupAuthMiddleware(settings.AuthFilePath, settings.AuthDefaultRole, logger) authSettings, err := auth.Read(authConfigPath)
switch {
case errors.Is(err, os.ErrNotExist): // no auth file present
case err != nil:
return nil, fmt.Errorf("reading auth settings: %w", err)
default:
logger.Infof("read %d roles from authentication file", len(authSettings.Roles))
}
authSettings.SetDefaults()
err = authSettings.Validate()
if err != nil { if err != nil {
return nil, fmt.Errorf("building authentication middleware settings: %w", err) return nil, fmt.Errorf("validating auth settings: %w", err)
} }
handler, err := newHandler(ctx, logger, *settings.Log, authSettings, buildInfo, handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper, openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported) storage, ipv6Supported)
if err != nil { if err != nil {
@@ -32,7 +40,7 @@ func New(ctx context.Context, settings settings.ControlServer, logger Logger,
} }
httpServerSettings := httpserver.Settings{ httpServerSettings := httpserver.Settings{
Address: *settings.Address, Address: address,
Handler: handler, Handler: handler,
Logger: logger, Logger: logger,
} }
@@ -44,26 +52,3 @@ func New(ctx context.Context, settings settings.ControlServer, logger Logger,
return server, nil return server, nil
} }
func setupAuthMiddleware(authPath, jsonDefaultRole string, logger Logger) (
authSettings auth.Settings, err error,
) {
authSettings, err = auth.Read(authPath)
switch {
case errors.Is(err, os.ErrNotExist): // no auth file present
case err != nil:
return auth.Settings{}, fmt.Errorf("reading auth settings: %w", err)
default:
logger.Infof("read %d roles from authentication file", len(authSettings.Roles))
}
err = authSettings.SetDefaultRole(jsonDefaultRole)
if err != nil {
return auth.Settings{}, fmt.Errorf("setting default role: %w", err)
}
authSettings.SetDefaults()
err = authSettings.Validate()
if err != nil {
return auth.Settings{}, fmt.Errorf("validating auth settings: %w", err)
}
return authSettings, nil
}

View File

@@ -1,3 +1,3 @@
package storage package storage
//go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . Logger //go:generate mockgen -destination=mocks_test.go -package $GOPACKAGE . Infoer

View File

@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/storage (interfaces: Logger) // Source: github.com/qdm12/gluetun/internal/storage (interfaces: Infoer)
// Package storage is a generated GoMock package. // Package storage is a generated GoMock package.
package storage package storage
@@ -10,49 +10,37 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockLogger is a mock of Logger interface. // MockInfoer is a mock of Infoer interface.
type MockLogger struct { type MockInfoer struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockLoggerMockRecorder recorder *MockInfoerMockRecorder
} }
// MockLoggerMockRecorder is the mock recorder for MockLogger. // MockInfoerMockRecorder is the mock recorder for MockInfoer.
type MockLoggerMockRecorder struct { type MockInfoerMockRecorder struct {
mock *MockLogger mock *MockInfoer
} }
// NewMockLogger creates a new mock instance. // NewMockInfoer creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger { func NewMockInfoer(ctrl *gomock.Controller) *MockInfoer {
mock := &MockLogger{ctrl: ctrl} mock := &MockInfoer{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock} mock.recorder = &MockInfoerMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { func (m *MockInfoer) EXPECT() *MockInfoerMockRecorder {
return m.recorder return m.recorder
} }
// Info mocks base method. // Info mocks base method.
func (m *MockLogger) Info(arg0 string) { func (m *MockInfoer) Info(arg0 string) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0) m.ctrl.Call(m, "Info", arg0)
} }
// Info indicates an expected call of Info. // Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { func (mr *MockInfoerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockInfoer)(nil).Info), arg0)
}
// Warn mocks base method.
func (m *MockLogger) Warn(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Warn", arg0)
}
// Warn indicates an expected call of Warn.
func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0)
} }

View File

@@ -95,7 +95,7 @@ func Test_extractServersFromBytes(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl) logger := NewMockInfoer(ctrl)
var previousLogCall *gomock.Call var previousLogCall *gomock.Call
for _, logged := range testCase.logged { for _, logged := range testCase.logged {
call := logger.EXPECT().Info(logged) call := logger.EXPECT().Info(logged)

View File

@@ -13,35 +13,30 @@ type Storage struct {
// the embedded JSON file on every call to the // the embedded JSON file on every call to the
// SyncServers method. // SyncServers method.
hardcodedServers models.AllServers hardcodedServers models.AllServers
logger Logger logger Infoer
filepath string filepath string
} }
type Logger interface { type Infoer interface {
Info(s string) Info(s string)
Warn(s string)
} }
// New creates a new storage and reads the servers from the // New creates a new storage and reads the servers from the
// embedded servers file and the file on disk. // embedded servers file and the file on disk.
// Passing an empty filepath disables the reading and writing of // Passing an empty filepath disables writing servers to a file.
// servers. func New(logger Infoer, filepath string) (storage *Storage, err error) {
func New(logger Logger, filepath string) (storage *Storage, err error) {
// A unit test prevents any error from being returned // A unit test prevents any error from being returned
// and ensures all providers are part of the servers returned. // and ensures all providers are part of the servers returned.
hardcodedServers, _ := parseHardcodedServers() hardcodedServers, _ := parseHardcodedServers()
storage = &Storage{ storage = &Storage{
hardcodedServers: hardcodedServers, hardcodedServers: hardcodedServers,
mergedServers: hardcodedServers,
logger: logger, logger: logger,
filepath: filepath, filepath: filepath,
} }
if filepath != "" { if err := storage.syncServers(); err != nil {
if err := storage.syncServers(); err != nil { return nil, err
return nil, err
}
} }
return storage, nil return storage, nil

View File

@@ -46,13 +46,13 @@ func (s *Storage) syncServers() (err error) {
} }
// Eventually write file // Eventually write file
if reflect.DeepEqual(serversOnFile, s.mergedServers) { if s.filepath == "" || reflect.DeepEqual(serversOnFile, s.mergedServers) {
return nil return nil
} }
err = s.flushToFile(s.filepath) err = s.flushToFile(s.filepath)
if err != nil { if err != nil {
s.logger.Warn("failed writing servers to file: " + err.Error()) return fmt.Errorf("writing servers to file: %w", err)
} }
return nil return nil
} }

View File

@@ -29,7 +29,7 @@ func (u *Updater) updateProvider(ctx context.Context, provider Provider,
u.logger.Warn("note: if running the update manually, you can use the flag " + u.logger.Warn("note: if running the update manually, you can use the flag " +
"-minratio to allow the update to succeed with less servers found") "-minratio to allow the update to succeed with less servers found")
} }
return fmt.Errorf("getting %s servers: %w", providerName, err) return fmt.Errorf("getting servers: %w", err)
} }
for _, server := range servers { for _, server := range servers {

View File

@@ -2,11 +2,9 @@ package updater
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"time" "time"
"github.com/qdm12/gluetun/internal/provider/common"
"github.com/qdm12/gluetun/internal/updater/unzip" "github.com/qdm12/gluetun/internal/updater/unzip"
"golang.org/x/text/cases" "golang.org/x/text/cases"
"golang.org/x/text/language" "golang.org/x/text/language"
@@ -50,22 +48,22 @@ func (u *Updater) UpdateServers(ctx context.Context, providers []string,
// TODO support servers offering only TCP or only UDP // TODO support servers offering only TCP or only UDP
// for NordVPN and PureVPN // for NordVPN and PureVPN
err := u.updateProvider(ctx, fetcher, minRatio) err := u.updateProvider(ctx, fetcher, minRatio)
switch { if err == nil {
case err == nil:
continue continue
case errors.Is(err, common.ErrCredentialsMissing):
u.logger.Warn(err.Error() + " - skipping update for " + providerName)
continue
case len(providers) == 1:
// return the only error for the single provider.
return err
case ctx.Err() != nil:
// stop updating other providers if context is done
return ctx.Err()
default: // error encountered updating one of multiple providers
// Log the error and continue updating the next provider.
u.logger.Error(err.Error())
} }
// return the only error for the single provider.
if len(providers) == 1 {
return err
}
// stop updating the next providers if context is canceled.
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
// Log the error and continue updating the next provider.
u.logger.Error(err.Error())
} }
return nil return nil

View File

@@ -81,6 +81,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) (err error)
} }
type DNSLoop interface { type DNSLoop interface {

View File

@@ -47,6 +47,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue continue
} }
tunnelUpData := tunnelUpData{ tunnelUpData := tunnelUpData{
vpnType: settings.Type,
serverIP: connection.IP, serverIP: connection.IP,
serverName: connection.ServerName, serverName: connection.ServerName,
canPortForward: connection.PortForward, canPortForward: connection.PortForward,

View File

@@ -2,15 +2,24 @@ package vpn
import ( import (
"context" "context"
"errors"
"fmt"
"net/netip" "net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
"github.com/qdm12/gluetun/internal/version" "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
) )
type tunnelUpData struct { type tunnelUpData struct {
// Healthcheck // Healthcheck
serverIP netip.Addr serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
// Port forwarding // Port forwarding
vpnIntf string vpnIntf string
serverName string // used for PIA serverName string // used for PIA
@@ -45,7 +54,21 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
return return
} }
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running) mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err = updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
l.netLinker, l.routing, mtuLogger)
if err != nil {
mtuLogger.Error(err.Error())
}
if *l.dnsLooper.GetSettings().ServerEnabled {
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
} else {
err := check.WaitForDNS(ctx, check.Settings{})
if err != nil {
l.logger.Error("waiting for DNS to be ready: " + err.Error())
}
}
err = l.publicip.RunOnce(ctx) err = l.publicip.RunOnce(ctx)
if err != nil { if err != nil {
@@ -104,3 +127,65 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
_, _ = l.ApplyStatus(ctx, constants.Stopped) _, _ = l.ApplyStatus(ctx, constants.Stopped)
_, _ = l.ApplyStatus(ctx, constants.Running) _, _ = l.ApplyStatus(ctx, constants.Running)
} }
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
) error {
logger.Info("finding maximum MTU, this can take up to 4 seconds")
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN gateway IP address: %w", err)
}
link, err := netlinker.LinkByName(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err)
}
originalMTU := link.MTU
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
switch {
case err == nil:
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
vpnLinkMTU = int(originalMTU)
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
vpnInterface, originalMTU, err)
default:
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
return nil
}