Wireguard support for Mullvad and Windscribe (#565)
- `internal/wireguard` client package with unit tests - Implementation works with kernel space or user space if unavailable - `WIREGUARD_PRIVATE_KEY` - `WIREGUARD_ADDRESS` - `WIREGUARD_PRESHARED_KEY` - `WIREGUARD_PORT` - `internal/netlink` package used by `internal/wireguard`
This commit is contained in:
@@ -1 +1,2 @@
|
||||
FROM qmcgaw/godevcontainer
|
||||
RUN apk add wireguard-tools
|
||||
|
||||
@@ -4,6 +4,8 @@ services:
|
||||
vscode:
|
||||
build: .
|
||||
image: godevcontainer
|
||||
devices:
|
||||
- /dev/net/tun:/dev/net/tun
|
||||
volumes:
|
||||
- ../:/workspace
|
||||
# Docker socket to access Docker server
|
||||
@@ -23,7 +25,8 @@ services:
|
||||
- TZ=
|
||||
cap_add:
|
||||
# For debugging with dlv
|
||||
- SYS_PTRACE
|
||||
# - SYS_PTRACE
|
||||
- NET_ADMIN
|
||||
security_opt:
|
||||
# For debugging with dlv
|
||||
- seccomp:unconfined
|
||||
|
||||
3
.github/labels.yml
vendored
3
.github/labels.yml
vendored
@@ -70,6 +70,9 @@
|
||||
- name: "Openvpn"
|
||||
color: "ffc7ea"
|
||||
description: ""
|
||||
- name: "Wireguard"
|
||||
color: "ffc7ea"
|
||||
description: ""
|
||||
- name: "Unbound (DNS over TLS)"
|
||||
color: "ffc7ea"
|
||||
description: ""
|
||||
|
||||
@@ -68,6 +68,7 @@ LABEL \
|
||||
org.opencontainers.image.description="VPN swiss-knife like client to tunnel to multiple VPN servers using OpenVPN, IPtables, DNS over TLS, Shadowsocks, an HTTP proxy and Alpine Linux"
|
||||
ENV VPNSP=pia \
|
||||
VERSION_INFORMATION=on \
|
||||
VPN_TYPE=openvpn \
|
||||
PROTOCOL=udp \
|
||||
OPENVPN_VERSION=2.5 \
|
||||
OPENVPN_VERBOSITY=1 \
|
||||
@@ -77,6 +78,11 @@ ENV VPNSP=pia \
|
||||
OPENVPN_IPV6=off \
|
||||
OPENVPN_CUSTOM_CONFIG= \
|
||||
OPENVPN_INTERFACE=tun0 \
|
||||
WIREGUARD_PRIVATE_KEY= \
|
||||
WIREGUARD_PRESHARED_KEY= \
|
||||
WIREGUARD_ADDRESS= \
|
||||
WIREGUARD_PORT= \
|
||||
WIREGUARD_INTERFACE=wg0 \
|
||||
TZ= \
|
||||
PUID= \
|
||||
PGID= \
|
||||
|
||||
@@ -5,7 +5,7 @@ HideMyAss, IPVanish, IVPN, Mullvad, NordVPN, Privado, Private Internet Access, P
|
||||
ProtonVPN, PureVPN, Surfshark, TorGuard, VPNUnlimited, VyprVPN and Windscribe VPN servers
|
||||
using Go, OpenVPN, iptables, DNS over TLS, ShadowSocks and an HTTP proxy*
|
||||
|
||||
**ANNOUNCEMENT**: You can try Wireguard, see #565
|
||||
**ANNOUNCEMENT**: Wireguard is now supported for all providers supporting it!
|
||||
|
||||

|
||||
|
||||
@@ -55,9 +55,10 @@ using Go, OpenVPN, iptables, DNS over TLS, ShadowSocks and an HTTP proxy*
|
||||
|
||||
## Features
|
||||
|
||||
- Based on Alpine 3.14 for a small Docker image of 30MB
|
||||
- Based on Alpine 3.14 for a small Docker image of 31MB
|
||||
- Supports: **Cyberghost**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **Surfshark**, **TorGuard**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
|
||||
- Supports OpenVPN and Wireguard (the latter in progress, see PR #565 and issue #134)
|
||||
- Supports OpenVPN
|
||||
- Supports Wireguard for **Mullvad** and **Windscribe** (more in progress, see #134)
|
||||
- DNS over TLS baked in with service provider(s) of your choice
|
||||
- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours
|
||||
- Choose the vpn network protocol, `udp` or `tcp`
|
||||
@@ -110,6 +111,7 @@ The following points are all optional but should give you insights on all the po
|
||||
|
||||
- [Test your setup](https://github.com/qdm12/gluetun/wiki/Test-your-setup)
|
||||
- [How to connect other containers and devices to Gluetun](https://github.com/qdm12/gluetun/wiki/Connect-to-gluetun)
|
||||
- [How to use Wireguard](https://github.com/qdm12/gluetun/wiki/Wireguard)
|
||||
- [VPN server side port forwarding](https://github.com/qdm12/gluetun/wiki/Port-forwarding)
|
||||
- [HTTP control server](https://github.com/qdm12/gluetun/wiki/HTTP-Control-server) to automate things, restart Openvpn etc.
|
||||
- Update the image with `docker pull qmcgaw/gluetun:latest`. See this [Wiki document](https://github.com/qdm12/gluetun/wiki/Docker-image-tags) for Docker tags available.
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/healthcheck"
|
||||
"github.com/qdm12/gluetun/internal/httpproxy"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/publicip"
|
||||
@@ -69,13 +70,14 @@ func main() {
|
||||
|
||||
args := os.Args
|
||||
tun := tun.New()
|
||||
netLinker := netlink.New()
|
||||
cli := cli.New()
|
||||
env := params.NewEnv()
|
||||
cmder := command.NewCmder()
|
||||
|
||||
errorCh := make(chan error)
|
||||
go func() {
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, env, tun, cmder, cli)
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, env, tun, netLinker, cmder, cli)
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -116,7 +118,8 @@ var (
|
||||
//nolint:gocognit,gocyclo
|
||||
func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
args []string, logger logging.ParentLogger, env params.Env,
|
||||
tun tun.Interface, cmder command.RunStarter, cli cli.CLIer) error {
|
||||
tun tun.Interface, netLinker netlink.NetLinker, cmder command.RunStarter,
|
||||
cli cli.CLIer) error {
|
||||
if len(args) > 1 { // cli operation
|
||||
switch args[1] {
|
||||
case "healthcheck":
|
||||
@@ -153,7 +156,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
dnsConf := unbound.NewConfigurator(nil, cmder, dnsCrypto,
|
||||
"/etc/unbound", "/usr/sbin/unbound", cacertsPath)
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2021-10-02T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -164,7 +167,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
BuildDate: buildInfo.Created,
|
||||
Announcement: "",
|
||||
Announcement: "Wireguard is now supported!",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -357,12 +360,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "})
|
||||
vpnLooper := vpn.NewLoop(allSettings.VPN,
|
||||
allServers, ovpnConf, firewallConf, routingConf, portForwardLooper,
|
||||
allServers, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
|
||||
cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
|
||||
buildInfo, allSettings.VersionInformation)
|
||||
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
|
||||
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
|
||||
go vpnLooper.Run(openvpnCtx, openvpnDone)
|
||||
vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler(
|
||||
"vpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
|
||||
go vpnLooper.Run(vpnCtx, vpnDone)
|
||||
|
||||
updaterLooper := updater.NewLooper(allSettings.Updater,
|
||||
allServers, storage, vpnLooper.SetServers, httpClient,
|
||||
@@ -417,7 +420,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
orderHandler := goshutdown.NewOrder("gluetun", orderSettings)
|
||||
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
|
||||
openvpnHandler, portForwardHandler, otherGroupHandler)
|
||||
vpnHandler, portForwardHandler, otherGroupHandler)
|
||||
|
||||
// Start VPN for the first time in a blocking call
|
||||
// until the VPN is launched
|
||||
|
||||
@@ -15,10 +15,11 @@ services:
|
||||
volumes:
|
||||
- /yourpath:/gluetun
|
||||
environment:
|
||||
# More variables are available, see the readme table
|
||||
# More variables are available, see the Wiki table
|
||||
- OPENVPN_USER=
|
||||
- OPENVPN_PASSWORD=
|
||||
- VPNSP=private internet access
|
||||
- VPN_TYPE=openvpn
|
||||
# Timezone for accurate logs times
|
||||
- TZ=
|
||||
restart: always
|
||||
|
||||
8
go.mod
8
go.mod
@@ -14,13 +14,19 @@ require (
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c
|
||||
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c
|
||||
inet.af/netaddr v0.0.0-20210718074554-06ca8145d722
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/go-cmp v0.5.5 // indirect
|
||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
|
||||
github.com/mattn/go-colorable v0.1.8 // indirect
|
||||
github.com/mattn/go-isatty v0.0.12 // indirect
|
||||
github.com/mdlayher/genetlink v1.0.0 // indirect
|
||||
github.com/mdlayher/netlink v1.4.0 // indirect
|
||||
github.com/miekg/dns v1.1.40 // indirect
|
||||
github.com/mr-tron/base58 v1.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
@@ -29,6 +35,6 @@ require (
|
||||
go4.org/intern v0.0.0-20210108033219-3eb7198706b2 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 // indirect
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 // indirect
|
||||
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
||||
)
|
||||
|
||||
73
go.sum
73
go.sum
@@ -33,11 +33,28 @@ github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3K
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gotify/go-api-client/v2 v2.0.4/go.mod h1:VKiah/UK20bXsr0JObE1eBVLW44zbBouzjuri9iwjFU=
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 h1:uhL5Gw7BINiiPAo24A2sxkcDI0Jt/sqp1v5xQCniEFA=
|
||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20201216134343-bde56ed16391/go.mod h1:cR77jAZG3Y3bsb8hF6fHJbFoyFukLFOkQ98S0pQz3xw=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20201220180245-69540ac93943/go.mod h1:z4c53zj6Eex712ROyh8WI0ihysb5j2ROyV42iNogmAs=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b h1:c3NTyLNozICy8B4mlMXemD3z/gXgQzVXZS/HqT+i3do=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
|
||||
github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
@@ -51,8 +68,24 @@ github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ
|
||||
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43 h1:WgyLFv10Ov49JAQI/ZLUkCZ7VJS3r74hwFIGXJsgZlY=
|
||||
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo=
|
||||
github.com/mdlayher/genetlink v1.0.0 h1:OoHN1OdyEIkScEmRgxLEe2M9U8ClMytqA5niynLtfj0=
|
||||
github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc=
|
||||
github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA=
|
||||
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
|
||||
github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY=
|
||||
github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o=
|
||||
github.com/mdlayher/netlink v1.2.0/go.mod h1:kwVW1io0AZy9A1E2YYgaD4Cj+C+GPkU6klXCMzIJ9p8=
|
||||
github.com/mdlayher/netlink v1.2.1/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU=
|
||||
github.com/mdlayher/netlink v1.2.2-0.20210123213345-5cc92139ae3e/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU=
|
||||
github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys=
|
||||
github.com/mdlayher/netlink v1.4.0 h1:n3ARR+Fm0dDv37dj5wSWZXDKcy+U0zwcXS3zKMnSiT0=
|
||||
github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8=
|
||||
github.com/miekg/dns v1.1.40 h1:pyyPFfGMnciYUk/mXpKkVmeMQjfXqt3FAJ2hy7tPiLA=
|
||||
github.com/miekg/dns v1.1.40/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||
@@ -106,6 +139,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20210503195802-e9a32991a82e/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
@@ -113,38 +148,67 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d h1:nTDGCTeAu2LhcsHTRzjyIUbZHCJ4QePArsm27Hka0UM=
|
||||
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190221075227-b4e8571b14e0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201218084310-7d0127a74742/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210503173754-0981d6026fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190729092621-ff9f1409240a/go.mod h1:jcCCGcm9btYwXyDqrUWc6MKQKKGJCWEQ3AfLSRIbEuI=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
@@ -153,7 +217,14 @@ golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20210427022245-097af6e1351b/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 h1:ab2jcw2W91Rz07eHAb8Lic7sFQKO0NhBftjv6m/gL/0=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19/go.mod h1:laHzsbfMhGSobUmruXWAyMKKHSqvIcrqZJMyHD+/3O8=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c h1:ADNrRDI5NR23/TUCnEmlLZLt4u9DnZ2nwRkPrAcFvto=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c/go.mod h1:+1XihzyZUBJcSc5WO9SwNA7v26puQwOEDwanaxfNXPQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -49,7 +49,7 @@ func (settings *OpenVPNSelection) readMullvad(env params.Env) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
settings.CustomPort, err = readCustomPort(env, settings.TCP,
|
||||
settings.CustomPort, err = readOpenVPNCustomPort(env, settings.TCP,
|
||||
[]uint16{80, 443, 1401}, []uint16{53, 1194, 1195, 1196, 1197, 1300, 1301, 1302, 1303, 1400})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -43,7 +43,7 @@ var (
|
||||
)
|
||||
|
||||
func (settings *Provider) read(r reader, vpnType string) error {
|
||||
err := settings.readVPNServiceProvider(r)
|
||||
err := settings.readVPNServiceProvider(r, vpnType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,11 +94,17 @@ func (settings *Provider) read(r reader, vpnType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (settings *Provider) readVPNServiceProvider(r reader) (err error) {
|
||||
allowedVPNServiceProviders := []string{
|
||||
"cyberghost", "fastestvpn", "hidemyass", "ipvanish", "ivpn", "mullvad", "nordvpn",
|
||||
"privado", "pia", "private internet access", "privatevpn", "protonvpn",
|
||||
"purevpn", "surfshark", "torguard", constants.VPNUnlimited, "vyprvpn", "windscribe"}
|
||||
func (settings *Provider) readVPNServiceProvider(r reader, vpnType string) (err error) {
|
||||
var allowedVPNServiceProviders []string
|
||||
switch vpnType {
|
||||
case constants.OpenVPN:
|
||||
allowedVPNServiceProviders = []string{
|
||||
"cyberghost", "fastestvpn", "hidemyass", "ipvanish", "ivpn", "mullvad", "nordvpn",
|
||||
"privado", "pia", "private internet access", "privatevpn", "protonvpn",
|
||||
"purevpn", "surfshark", "torguard", constants.VPNUnlimited, "vyprvpn", "windscribe"}
|
||||
case constants.Wireguard:
|
||||
allowedVPNServiceProviders = []string{constants.Mullvad, constants.Windscribe}
|
||||
}
|
||||
|
||||
vpnsp, err := r.env.Inside("VPNSP", allowedVPNServiceProviders,
|
||||
params.Default("private internet access"))
|
||||
@@ -132,7 +138,7 @@ func readTargetIP(env params.Env) (targetIP net.IP, err error) {
|
||||
return targetIP, nil
|
||||
}
|
||||
|
||||
func readCustomPort(env params.Env, tcp bool,
|
||||
func readOpenVPNCustomPort(env params.Env, tcp bool,
|
||||
allowedTCP, allowedUDP []uint16) (port uint16, err error) {
|
||||
port, err = readPortOrZero(env, "PORT")
|
||||
if err != nil {
|
||||
@@ -147,12 +153,42 @@ func readCustomPort(env params.Env, tcp bool,
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("environment variable PORT: %w: port %d for TCP protocol", ErrInvalidPort, port)
|
||||
return 0, fmt.Errorf(
|
||||
"environment variable PORT: %w: port %d for TCP protocol, can only be one of %s",
|
||||
ErrInvalidPort, port, portsToString(allowedTCP))
|
||||
}
|
||||
for i := range allowedUDP {
|
||||
if allowedUDP[i] == port {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("environment variable PORT: %w: port %d for UDP protocol", ErrInvalidPort, port)
|
||||
return 0, fmt.Errorf(
|
||||
"environment variable PORT: %w: port %d for UDP protocol, can only be one of %s",
|
||||
ErrInvalidPort, port, portsToString(allowedUDP))
|
||||
}
|
||||
|
||||
func readWireguardCustomPort(env params.Env, allowed []uint16) (port uint16, err error) {
|
||||
port, err = readPortOrZero(env, "WIREGUARD_PORT")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("environment variable WIREGUARD_PORT: %w", err)
|
||||
} else if port == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
for i := range allowed {
|
||||
if allowed[i] == port {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf(
|
||||
"environment variable WIREGUARD_PORT: %w: port %d, can only be one of %s",
|
||||
ErrInvalidPort, port, portsToString(allowed))
|
||||
}
|
||||
|
||||
func portsToString(ports []uint16) string {
|
||||
slice := make([]string, len(ports))
|
||||
for i := range ports {
|
||||
slice[i] = fmt.Sprint(ports[i])
|
||||
}
|
||||
return strings.Join(slice, ", ")
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Cyberghost,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Groups: []string{"group"},
|
||||
Regions: []string{"a", "El country"},
|
||||
},
|
||||
@@ -40,6 +41,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Fastestvpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Hostnames: []string{"a", "b"},
|
||||
Countries: []string{"c", "d"},
|
||||
},
|
||||
@@ -56,6 +58,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.HideMyAss,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
Hostnames: []string{"e", "f"},
|
||||
@@ -74,6 +77,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Ipvanish,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
Hostnames: []string{"e", "f"},
|
||||
@@ -92,6 +96,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Ivpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
Hostnames: []string{"e", "f"},
|
||||
@@ -110,6 +115,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Mullvad,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
ISPs: []string{"e", "f"},
|
||||
@@ -132,6 +138,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Nordvpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Regions: []string{"a", "b"},
|
||||
Numbers: []uint16{1, 2},
|
||||
},
|
||||
@@ -148,6 +155,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Privado,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Hostnames: []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
@@ -162,6 +170,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Privatevpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Hostnames: []string{"a", "b"},
|
||||
Countries: []string{"c", "d"},
|
||||
Cities: []string{"e", "f"},
|
||||
@@ -180,6 +189,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Protonvpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Regions: []string{"c", "d"},
|
||||
Cities: []string{"e", "f"},
|
||||
@@ -202,6 +212,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.PrivateInternetAccess,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Regions: []string{"a", "b"},
|
||||
OpenVPN: OpenVPNSelection{
|
||||
CustomPort: 1,
|
||||
@@ -226,6 +237,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Purevpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Regions: []string{"a", "b"},
|
||||
Countries: []string{"c", "d"},
|
||||
Cities: []string{"e", "f"},
|
||||
@@ -244,6 +256,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Surfshark,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Regions: []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
@@ -258,6 +271,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Torguard,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
Hostnames: []string{"e"},
|
||||
@@ -276,6 +290,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.VPNUnlimited,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Countries: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
Hostnames: []string{"e", "f"},
|
||||
@@ -298,6 +313,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Vyprvpn,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Regions: []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
@@ -312,6 +328,7 @@ func Test_Provider_lines(t *testing.T) {
|
||||
settings: Provider{
|
||||
Name: constants.Windscribe,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
Regions: []string{"a", "b"},
|
||||
Cities: []string{"c", "d"},
|
||||
Hostnames: []string{"e", "f"},
|
||||
|
||||
@@ -109,12 +109,12 @@ func readIP(env params.Env, key string) (ip net.IP, err error) {
|
||||
}
|
||||
|
||||
func readPortOrZero(env params.Env, key string) (port uint16, err error) {
|
||||
s, err := env.Get(key)
|
||||
s, err := env.Get(key, params.Default("0"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if s == "" || s == "0" {
|
||||
if s == "0" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
type ServerSelection struct { //nolint:maligned
|
||||
// Common
|
||||
VPN string `json:"vpn"`
|
||||
VPN string `json:"vpn"` // note: this is required
|
||||
TargetIP net.IP `json:"target_ip,omitempty"`
|
||||
// TODO comments
|
||||
// Cyberghost, PIA, Protonvpn, Surfshark, Windscribe, Vyprvpn, NordVPN
|
||||
@@ -39,7 +40,8 @@ type ServerSelection struct { //nolint:maligned
|
||||
// VPNUnlimited
|
||||
StreamOnly bool `json:"stream_only"`
|
||||
|
||||
OpenVPN OpenVPNSelection `json:"openvpn"`
|
||||
OpenVPN OpenVPNSelection `json:"openvpn"`
|
||||
Wireguard WireguardSelection `json:"wireguard"`
|
||||
}
|
||||
|
||||
func (selection ServerSelection) toLines() (lines []string) {
|
||||
@@ -91,7 +93,11 @@ func (selection ServerSelection) toLines() (lines []string) {
|
||||
lines = append(lines, lastIndent+"Numbers: "+commaJoin(numbersString))
|
||||
}
|
||||
|
||||
lines = append(lines, selection.OpenVPN.lines()...)
|
||||
if selection.VPN == constants.OpenVPN {
|
||||
lines = append(lines, selection.OpenVPN.lines()...)
|
||||
} else { // wireguard
|
||||
lines = append(lines, selection.Wireguard.lines()...)
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
@@ -137,6 +143,20 @@ func (settings *OpenVPNSelection) readProtocolAndPort(env params.Env) (err error
|
||||
return nil
|
||||
}
|
||||
|
||||
type WireguardSelection struct {
|
||||
CustomPort uint16 `json:"custom_port"` // Mullvad
|
||||
}
|
||||
|
||||
func (settings *WireguardSelection) lines() (lines []string) {
|
||||
lines = append(lines, lastIndent+"Wireguard selection:")
|
||||
|
||||
if settings.CustomPort != 0 {
|
||||
lines = append(lines, indent+lastIndent+"Custom port: "+fmt.Sprint(settings.CustomPort))
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
// PortForwarding contains settings for port forwarding.
|
||||
type PortForwarding struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -20,6 +20,9 @@ func Test_Settings_lines(t *testing.T) {
|
||||
Type: constants.OpenVPN,
|
||||
Provider: Provider{
|
||||
Name: constants.Mullvad,
|
||||
ServerSelection: ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
},
|
||||
OpenVPN: OpenVPN{
|
||||
Version: constants.Openvpn25,
|
||||
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
)
|
||||
|
||||
type VPN struct {
|
||||
Type string `json:"type"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Provider Provider `json:"provider"`
|
||||
Type string `json:"type"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
Provider Provider `json:"provider"`
|
||||
}
|
||||
|
||||
func (settings *VPN) String() string {
|
||||
@@ -24,7 +25,14 @@ func (settings *VPN) lines() (lines []string) {
|
||||
|
||||
lines = append(lines, indent+lastIndent+"Type: "+settings.Type)
|
||||
|
||||
for _, line := range settings.OpenVPN.lines() {
|
||||
var vpnLines []string
|
||||
switch settings.Type {
|
||||
case constants.OpenVPN:
|
||||
vpnLines = settings.OpenVPN.lines()
|
||||
case constants.Wireguard:
|
||||
vpnLines = settings.Wireguard.lines()
|
||||
}
|
||||
for _, line := range vpnLines {
|
||||
lines = append(lines, indent+line)
|
||||
}
|
||||
|
||||
@@ -36,13 +44,15 @@ func (settings *VPN) lines() (lines []string) {
|
||||
}
|
||||
|
||||
var (
|
||||
errReadProviderSettings = errors.New("cannot read provider settings")
|
||||
errReadOpenVPNSettings = errors.New("cannot read OpenVPN settings")
|
||||
errReadProviderSettings = errors.New("cannot read provider settings")
|
||||
errReadOpenVPNSettings = errors.New("cannot read OpenVPN settings")
|
||||
errReadWireguardSettings = errors.New("cannot read Wireguard settings")
|
||||
)
|
||||
|
||||
func (settings *VPN) read(r reader) (err error) {
|
||||
vpnType, err := r.env.Inside("VPN_TYPE",
|
||||
[]string{constants.OpenVPN}, params.Default(constants.OpenVPN))
|
||||
[]string{constants.OpenVPN, constants.Wireguard},
|
||||
params.Default(constants.OpenVPN))
|
||||
if err != nil {
|
||||
return fmt.Errorf("environment variable VPN_TYPE: %w", err)
|
||||
}
|
||||
@@ -54,9 +64,17 @@ func (settings *VPN) read(r reader) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
err = settings.OpenVPN.read(r, settings.Provider.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errReadOpenVPNSettings, err)
|
||||
switch settings.Type {
|
||||
case constants.OpenVPN:
|
||||
err = settings.OpenVPN.read(r, settings.Provider.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errReadOpenVPNSettings, err)
|
||||
}
|
||||
case constants.Wireguard:
|
||||
err = settings.Wireguard.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errReadWireguardSettings, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -30,7 +30,12 @@ func (settings *Provider) readWindscribe(r reader) (err error) {
|
||||
return fmt.Errorf("environment variable SERVER_HOSTNAME: %w", err)
|
||||
}
|
||||
|
||||
return settings.ServerSelection.OpenVPN.readWindscribe(r.env)
|
||||
err = settings.ServerSelection.OpenVPN.readWindscribe(r.env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return settings.ServerSelection.Wireguard.readWindscribe(r.env)
|
||||
}
|
||||
|
||||
func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
|
||||
@@ -39,7 +44,7 @@ func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
settings.CustomPort, err = readCustomPort(env, settings.TCP,
|
||||
settings.CustomPort, err = readOpenVPNCustomPort(env, settings.TCP,
|
||||
[]uint16{21, 22, 80, 123, 143, 443, 587, 1194, 3306, 8080, 54783},
|
||||
[]uint16{53, 80, 123, 443, 1194, 54783})
|
||||
if err != nil {
|
||||
@@ -48,3 +53,13 @@ func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (settings *WireguardSelection) readWindscribe(env params.Env) (err error) {
|
||||
settings.CustomPort, err = readWireguardCustomPort(env,
|
||||
[]uint16{53, 80, 123, 443, 1194, 65142})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
73
internal/configuration/wireguard.go
Normal file
73
internal/configuration/wireguard.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package configuration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
// Wireguard contains settings to configure the Wireguard client.
|
||||
type Wireguard struct {
|
||||
PrivateKey string `json:"privatekey"`
|
||||
PreSharedKey string `json:"presharedkey"`
|
||||
Address *net.IPNet `json:"address"`
|
||||
Interface string `json:"interface"`
|
||||
}
|
||||
|
||||
func (settings *Wireguard) String() string {
|
||||
return strings.Join(settings.lines(), "\n")
|
||||
}
|
||||
|
||||
func (settings *Wireguard) lines() (lines []string) {
|
||||
lines = append(lines, lastIndent+"Wireguard:")
|
||||
|
||||
lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface)
|
||||
|
||||
if settings.PrivateKey != "" {
|
||||
lines = append(lines, indent+lastIndent+"Private key is set")
|
||||
}
|
||||
|
||||
if settings.PreSharedKey != "" {
|
||||
lines = append(lines, indent+lastIndent+"Pre-shared key is set")
|
||||
}
|
||||
|
||||
if settings.Address != nil {
|
||||
lines = append(lines, indent+lastIndent+"Address: "+settings.Address.String())
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
func (settings *Wireguard) read(r reader) (err error) {
|
||||
settings.PrivateKey, err = r.env.Get("WIREGUARD_PRIVATE_KEY",
|
||||
params.CaseSensitiveValue(), params.Unset(), params.Compulsory())
|
||||
if err != nil {
|
||||
return fmt.Errorf("environment variable WIREGUARD_PRIVATE_KEY: %w", err)
|
||||
}
|
||||
|
||||
settings.PreSharedKey, err = r.env.Get("WIREGUARD_PRESHARED_KEY",
|
||||
params.CaseSensitiveValue(), params.Unset())
|
||||
if err != nil {
|
||||
return fmt.Errorf("environment variable WIREGUARD_PRESHARED_KEY: %w", err)
|
||||
}
|
||||
|
||||
addressString, err := r.env.Get("WIREGUARD_ADDRESS", params.Compulsory())
|
||||
if err != nil {
|
||||
return fmt.Errorf("environment variable WIREGUARD_ADDRESS: %w", err)
|
||||
}
|
||||
ip, ipNet, err := net.ParseCIDR(addressString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("environment variable WIREGUARD_ADDRESS: %w", err)
|
||||
}
|
||||
ipNet.IP = ip
|
||||
settings.Address = ipNet
|
||||
|
||||
settings.Interface, err = r.env.Get("WIREGUARD_INTERFACE", params.Default("wg0"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("environment variable WIREGUARD_INTERFACE: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,7 +73,7 @@ func Test_versions(t *testing.T) {
|
||||
"Mullvad": {
|
||||
model: models.MullvadServer{},
|
||||
version: allServers.Mullvad.Version,
|
||||
digest: "2a009192",
|
||||
digest: "ec56f19d",
|
||||
},
|
||||
"Nordvpn": {
|
||||
model: models.NordvpnServer{},
|
||||
@@ -128,7 +128,7 @@ func Test_versions(t *testing.T) {
|
||||
"Windscribe": {
|
||||
model: models.WindscribeServer{},
|
||||
version: allServers.Windscribe.Version,
|
||||
digest: "6f6c16d6",
|
||||
digest: "4bd0fc4f",
|
||||
},
|
||||
}
|
||||
for name, testCase := range testCases {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
OpenVPN = "openvpn"
|
||||
OpenVPN = "openvpn"
|
||||
Wireguard = "wireguard"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
// Type is the connection type and can be "openvpn"
|
||||
// Type is the connection type and can be "openvpn" or "wireguard"
|
||||
Type string `json:"type"`
|
||||
// IP is the VPN server IP address.
|
||||
IP net.IP `json:"ip"`
|
||||
@@ -15,13 +15,17 @@ type Connection struct {
|
||||
// Protocol can be "tcp" or "udp".
|
||||
Protocol string `json:"protocol"`
|
||||
// Hostname is used for IPVanish, IVPN, Privado
|
||||
// and Windscribe for TLS verification
|
||||
// and Windscribe for TLS verification.
|
||||
Hostname string `json:"hostname"`
|
||||
// PubKey is the public key of the VPN server,
|
||||
// used only for Wireguard.
|
||||
PubKey string `json:"pubkey"`
|
||||
}
|
||||
|
||||
func (c *Connection) Equal(other Connection) bool {
|
||||
return c.IP.Equal(other.IP) && c.Port == other.Port &&
|
||||
c.Protocol == other.Protocol && c.Hostname == other.Hostname
|
||||
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
|
||||
c.PubKey == other.PubKey
|
||||
}
|
||||
|
||||
func (c Connection) OpenVPNRemoteLine() (line string) {
|
||||
|
||||
@@ -48,6 +48,7 @@ type IvpnServer struct {
|
||||
}
|
||||
|
||||
type MullvadServer struct {
|
||||
VPN string `json:"vpn"`
|
||||
IPs []net.IP `json:"ips"`
|
||||
IPsV6 []net.IP `json:"ipsv6"`
|
||||
Country string `json:"country"`
|
||||
@@ -55,6 +56,7 @@ type MullvadServer struct {
|
||||
Hostname string `json:"hostname"`
|
||||
ISP string `json:"isp"`
|
||||
Owned bool `json:"owned"`
|
||||
WgPubKey string `json:"wgpubkey,omitempty"`
|
||||
}
|
||||
|
||||
type NordvpnServer struct { //nolint:maligned
|
||||
@@ -149,9 +151,11 @@ type VyprvpnServer struct {
|
||||
}
|
||||
|
||||
type WindscribeServer struct {
|
||||
VPN string `json:"vpn"`
|
||||
Region string `json:"region"`
|
||||
City string `json:"city"`
|
||||
Hostname string `json:"hostname"`
|
||||
OvpnX509 string `json:"x509"`
|
||||
OvpnX509 string `json:"x509,omitempty"`
|
||||
WgPubKey string `json:"wgpubkey,omitempty"`
|
||||
IPs []net.IP `json:"ips"`
|
||||
}
|
||||
|
||||
7
internal/netlink/address.go
Normal file
7
internal/netlink/address.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package netlink
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
func (n *NetLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error {
|
||||
return netlink.AddrAdd(link, addr)
|
||||
}
|
||||
14
internal/netlink/interface.go
Normal file
14
internal/netlink/interface.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package netlink
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . NetLinker
|
||||
|
||||
var _ NetLinker = (*NetLink)(nil)
|
||||
|
||||
type NetLinker interface {
|
||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||
RouteAdd(route *netlink.Route) error
|
||||
RuleAdd(rule *netlink.Rule) error
|
||||
RuleDel(rule *netlink.Rule) error
|
||||
}
|
||||
7
internal/netlink/netlink.go
Normal file
7
internal/netlink/netlink.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package netlink
|
||||
|
||||
type NetLink struct{}
|
||||
|
||||
func New() *NetLink {
|
||||
return &NetLink{}
|
||||
}
|
||||
7
internal/netlink/route.go
Normal file
7
internal/netlink/route.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package netlink
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
func (n *NetLink) RouteAdd(route *netlink.Route) error {
|
||||
return netlink.RouteAdd(route)
|
||||
}
|
||||
11
internal/netlink/rule.go
Normal file
11
internal/netlink/rule.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package netlink
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
func (n *NetLink) RuleAdd(rule *netlink.Rule) error {
|
||||
return netlink.RuleAdd(rule)
|
||||
}
|
||||
|
||||
func (n *NetLink) RuleDel(rule *netlink.Rule) error {
|
||||
return netlink.RuleDel(rule)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -19,7 +20,8 @@ func Test_Cyberghost_filterServers(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
"no servers": {
|
||||
err: errors.New("no server found: for protocol udp"),
|
||||
selection: configuration.ServerSelection{VPN: constants.OpenVPN},
|
||||
err: errors.New("no server found: for VPN openvpn; protocol udp"),
|
||||
},
|
||||
"servers without filter defaults to UDP": {
|
||||
servers: []models.CyberghostServer{
|
||||
|
||||
@@ -9,16 +9,8 @@ import (
|
||||
|
||||
func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
var port uint16 = 1194
|
||||
protocol := constants.UDP
|
||||
if selection.OpenVPN.TCP {
|
||||
port = 443
|
||||
protocol = constants.TCP
|
||||
}
|
||||
|
||||
if selection.OpenVPN.CustomPort > 0 {
|
||||
port = selection.OpenVPN.CustomPort
|
||||
}
|
||||
port := getPort(selection)
|
||||
protocol := getProtocol(selection)
|
||||
|
||||
servers, err := m.filterServers(selection)
|
||||
if err != nil {
|
||||
@@ -33,6 +25,7 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
|
||||
IP: IP,
|
||||
Port: port,
|
||||
Protocol: protocol,
|
||||
PubKey: server.WgPubKey, // Wireguard only
|
||||
}
|
||||
connections = append(connections, connection)
|
||||
}
|
||||
@@ -44,3 +37,33 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
|
||||
|
||||
return utils.PickRandomConnection(connections, m.randSource), nil
|
||||
}
|
||||
|
||||
func getPort(selection configuration.ServerSelection) (port uint16) {
|
||||
switch selection.VPN {
|
||||
case constants.Wireguard:
|
||||
customPort := selection.Wireguard.CustomPort
|
||||
if customPort > 0 {
|
||||
return customPort
|
||||
}
|
||||
const defaultPort = 51820
|
||||
return defaultPort
|
||||
default: // OpenVPN
|
||||
customPort := selection.OpenVPN.CustomPort
|
||||
if customPort > 0 {
|
||||
return customPort
|
||||
}
|
||||
port = 1194
|
||||
if selection.OpenVPN.TCP {
|
||||
port = 443
|
||||
}
|
||||
return port
|
||||
}
|
||||
}
|
||||
|
||||
func getProtocol(selection configuration.ServerSelection) (protocol string) {
|
||||
protocol = constants.UDP
|
||||
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
|
||||
protocol = constants.TCP
|
||||
}
|
||||
return protocol
|
||||
}
|
||||
|
||||
204
internal/provider/mullvad/connection_test.go
Normal file
204
internal/provider/mullvad/connection_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package mullvad
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Mullvad_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.MullvadServer
|
||||
selection configuration.ServerSelection
|
||||
connection models.Connection
|
||||
err error
|
||||
}{
|
||||
"no server available": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
err: errors.New("no server found: for VPN openvpn; protocol udp"),
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.MullvadServer{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
selection: configuration.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: configuration.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
|
||||
connection, err := m.GetConnection(testCase.selection)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.connection, connection)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
selection configuration.ServerSelection
|
||||
port uint16
|
||||
}{
|
||||
"default": {
|
||||
port: 1194,
|
||||
},
|
||||
"OpenVPN UDP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
port: 1194,
|
||||
},
|
||||
"OpenVPN TCP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
OpenVPN: configuration.OpenVPNSelection{
|
||||
TCP: true,
|
||||
},
|
||||
},
|
||||
port: 443,
|
||||
},
|
||||
"OpenVPN custom port": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
OpenVPN: configuration.OpenVPNSelection{
|
||||
CustomPort: 1234,
|
||||
},
|
||||
},
|
||||
port: 1234,
|
||||
},
|
||||
"Wireguard": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
},
|
||||
port: 51820,
|
||||
},
|
||||
"Wireguard custom port": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
Wireguard: configuration.WireguardSelection{
|
||||
CustomPort: 1234,
|
||||
},
|
||||
},
|
||||
port: 1234,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port := getPort(testCase.selection)
|
||||
|
||||
assert.Equal(t, testCase.port, port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
selection configuration.ServerSelection
|
||||
protocol string
|
||||
}{
|
||||
"default": {
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
"OpenVPN UDP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
"OpenVPN TCP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
OpenVPN: configuration.OpenVPNSelection{
|
||||
TCP: true,
|
||||
},
|
||||
},
|
||||
protocol: constants.TCP,
|
||||
},
|
||||
"Wireguard": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
},
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
protocol := getProtocol(testCase.selection)
|
||||
|
||||
assert.Equal(t, testCase.protocol, protocol)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ func (m *Mullvad) filterServers(selection configuration.ServerSelection) (
|
||||
for _, server := range m.servers {
|
||||
switch {
|
||||
case
|
||||
server.VPN != selection.VPN,
|
||||
utils.FilterByPossibilities(server.Country, selection.Countries),
|
||||
utils.FilterByPossibilities(server.City, selection.Cities),
|
||||
utils.FilterByPossibilities(server.ISP, selection.ISPs),
|
||||
|
||||
143
internal/provider/mullvad/filter_test.go
Normal file
143
internal/provider/mullvad/filter_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package mullvad
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Mullvad_filterServers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.MullvadServer
|
||||
selection configuration.ServerSelection
|
||||
filtered []models.MullvadServer
|
||||
err error
|
||||
}{
|
||||
"no server available": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
err: errors.New("no server found: for VPN openvpn; protocol udp"),
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.MullvadServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b"},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b"},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
},
|
||||
"filter OpenVPN out": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{VPN: constants.OpenVPN, Hostname: "a"},
|
||||
{VPN: constants.Wireguard, Hostname: "b"},
|
||||
{VPN: constants.OpenVPN, Hostname: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{VPN: constants.Wireguard, Hostname: "b"},
|
||||
},
|
||||
},
|
||||
"filter by country": {
|
||||
selection: configuration.ServerSelection{
|
||||
Countries: []string{"b"},
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{Country: "a"},
|
||||
{Country: "b"},
|
||||
{Country: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{Country: "b"},
|
||||
},
|
||||
},
|
||||
"filter by city": {
|
||||
selection: configuration.ServerSelection{
|
||||
Cities: []string{"b"},
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{City: "a"},
|
||||
{City: "b"},
|
||||
{City: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{City: "b"},
|
||||
},
|
||||
},
|
||||
"filter by ISP": {
|
||||
selection: configuration.ServerSelection{
|
||||
ISPs: []string{"b"},
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{ISP: "a"},
|
||||
{ISP: "b"},
|
||||
{ISP: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{ISP: "b"},
|
||||
},
|
||||
},
|
||||
"filter by hostname": {
|
||||
selection: configuration.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b"},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{Hostname: "b"},
|
||||
},
|
||||
},
|
||||
"filter by owned": {
|
||||
selection: configuration.ServerSelection{
|
||||
Owned: true,
|
||||
},
|
||||
servers: []models.MullvadServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b", Owned: true},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
filtered: []models.MullvadServer{
|
||||
{Hostname: "b", Owned: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
|
||||
servers, err := m.filterServers(testCase.selection)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.filtered, servers)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,8 @@ var ErrNoServerFound = errors.New("no server found")
|
||||
func NoServerFoundError(selection configuration.ServerSelection) (err error) {
|
||||
var messageParts []string
|
||||
|
||||
messageParts = append(messageParts, "VPN "+selection.VPN)
|
||||
|
||||
protocol := constants.UDP
|
||||
if selection.OpenVPN.TCP {
|
||||
protocol = constants.TCP
|
||||
|
||||
34
internal/provider/utils/wireguard.go
Normal file
34
internal/provider/utils/wireguard.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
)
|
||||
|
||||
func BuildWireguardSettings(connection models.Connection,
|
||||
userSettings configuration.Wireguard) (settings wireguard.Settings) {
|
||||
settings.PrivateKey = userSettings.PrivateKey
|
||||
settings.PublicKey = connection.PubKey
|
||||
settings.PreSharedKey = userSettings.PreSharedKey
|
||||
settings.InterfaceName = userSettings.Interface
|
||||
|
||||
const routePriority = 101 // 100 is to receive external connections
|
||||
settings.RulePriority = routePriority
|
||||
|
||||
settings.Endpoint = new(net.UDPAddr)
|
||||
settings.Endpoint.IP = make(net.IP, len(connection.IP))
|
||||
copy(settings.Endpoint.IP, connection.IP)
|
||||
settings.Endpoint.Port = int(connection.Port)
|
||||
|
||||
address := new(net.IPNet)
|
||||
address.IP = make(net.IP, len(userSettings.Address.IP))
|
||||
copy(address.IP, userSettings.Address.IP)
|
||||
address.Mask = make(net.IPMask, len(userSettings.Address.Mask))
|
||||
copy(address.Mask, userSettings.Address.Mask)
|
||||
settings.Addresses = append(settings.Addresses, address)
|
||||
|
||||
return settings
|
||||
}
|
||||
@@ -9,16 +9,8 @@ import (
|
||||
|
||||
func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
|
||||
connection models.Connection, err error) {
|
||||
protocol := constants.UDP
|
||||
var port uint16 = 443
|
||||
if selection.OpenVPN.TCP {
|
||||
protocol = constants.TCP
|
||||
port = 1194
|
||||
}
|
||||
|
||||
if selection.OpenVPN.CustomPort > 0 {
|
||||
port = selection.OpenVPN.CustomPort
|
||||
}
|
||||
port := getPort(selection)
|
||||
protocol := getProtocol(selection)
|
||||
|
||||
servers, err := w.filterServers(selection)
|
||||
if err != nil {
|
||||
@@ -34,6 +26,7 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
|
||||
Port: port,
|
||||
Protocol: protocol,
|
||||
Hostname: server.OvpnX509,
|
||||
PubKey: server.WgPubKey,
|
||||
}
|
||||
connections = append(connections, connection)
|
||||
}
|
||||
@@ -45,3 +38,33 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
|
||||
|
||||
return utils.PickRandomConnection(connections, w.randSource), nil
|
||||
}
|
||||
|
||||
func getPort(selection configuration.ServerSelection) (port uint16) {
|
||||
switch selection.VPN {
|
||||
case constants.Wireguard:
|
||||
customPort := selection.Wireguard.CustomPort
|
||||
if customPort > 0 {
|
||||
return customPort
|
||||
}
|
||||
const defaultPort = 1194
|
||||
return defaultPort
|
||||
default: // OpenVPN
|
||||
customPort := selection.OpenVPN.CustomPort
|
||||
if customPort > 0 {
|
||||
return customPort
|
||||
}
|
||||
port = 1194
|
||||
if selection.OpenVPN.TCP {
|
||||
port = 443
|
||||
}
|
||||
return port
|
||||
}
|
||||
}
|
||||
|
||||
func getProtocol(selection configuration.ServerSelection) (protocol string) {
|
||||
protocol = constants.UDP
|
||||
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
|
||||
protocol = constants.TCP
|
||||
}
|
||||
return protocol
|
||||
}
|
||||
|
||||
204
internal/provider/windscribe/connection_test.go
Normal file
204
internal/provider/windscribe/connection_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package windscribe
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Windscribe_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.WindscribeServer
|
||||
selection configuration.ServerSelection
|
||||
connection models.Connection
|
||||
err error
|
||||
}{
|
||||
"no server available": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
err: errors.New("no server found: for VPN openvpn; protocol udp"),
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.WindscribeServer{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
IP: net.IPv4(1, 1, 1, 1),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"target IP": {
|
||||
selection: configuration.ServerSelection{
|
||||
TargetIP: net.IPv4(2, 2, 2, 2),
|
||||
},
|
||||
servers: []models.WindscribeServer{
|
||||
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
"with filter": {
|
||||
selection: configuration.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
},
|
||||
servers: []models.WindscribeServer{
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
|
||||
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
|
||||
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
|
||||
},
|
||||
connection: models.Connection{
|
||||
IP: net.IPv4(2, 2, 2, 2),
|
||||
Port: 1194,
|
||||
Protocol: constants.UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
|
||||
connection, err := m.GetConnection(testCase.selection)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.connection, connection)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
selection configuration.ServerSelection
|
||||
port uint16
|
||||
}{
|
||||
"default": {
|
||||
port: 1194,
|
||||
},
|
||||
"OpenVPN UDP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
port: 1194,
|
||||
},
|
||||
"OpenVPN TCP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
OpenVPN: configuration.OpenVPNSelection{
|
||||
TCP: true,
|
||||
},
|
||||
},
|
||||
port: 443,
|
||||
},
|
||||
"OpenVPN custom port": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
OpenVPN: configuration.OpenVPNSelection{
|
||||
CustomPort: 1234,
|
||||
},
|
||||
},
|
||||
port: 1234,
|
||||
},
|
||||
"Wireguard": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
},
|
||||
port: 1194,
|
||||
},
|
||||
"Wireguard custom port": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
Wireguard: configuration.WireguardSelection{
|
||||
CustomPort: 1234,
|
||||
},
|
||||
},
|
||||
port: 1234,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port := getPort(testCase.selection)
|
||||
|
||||
assert.Equal(t, testCase.port, port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
selection configuration.ServerSelection
|
||||
protocol string
|
||||
}{
|
||||
"default": {
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
"OpenVPN UDP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
"OpenVPN TCP": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
OpenVPN: configuration.OpenVPNSelection{
|
||||
TCP: true,
|
||||
},
|
||||
},
|
||||
protocol: constants.TCP,
|
||||
},
|
||||
"Wireguard": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
},
|
||||
protocol: constants.UDP,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
protocol := getProtocol(testCase.selection)
|
||||
|
||||
assert.Equal(t, testCase.protocol, protocol)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ func (w *Windscribe) filterServers(selection configuration.ServerSelection) (
|
||||
for _, server := range w.servers {
|
||||
switch {
|
||||
case
|
||||
server.VPN != selection.VPN,
|
||||
utils.FilterByPossibilities(server.Region, selection.Regions),
|
||||
utils.FilterByPossibilities(server.City, selection.Cities),
|
||||
utils.FilterByPossibilities(server.Hostname, selection.Hostnames):
|
||||
|
||||
117
internal/provider/windscribe/filter_test.go
Normal file
117
internal/provider/windscribe/filter_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package windscribe
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Windscribe_filterServers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.WindscribeServer
|
||||
selection configuration.ServerSelection
|
||||
filtered []models.WindscribeServer
|
||||
err error
|
||||
}{
|
||||
"no server available": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.OpenVPN,
|
||||
},
|
||||
err: errors.New("no server found: for VPN openvpn; protocol udp"),
|
||||
},
|
||||
"no filter": {
|
||||
servers: []models.WindscribeServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b"},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
filtered: []models.WindscribeServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b"},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
},
|
||||
"filter OpenVPN out": {
|
||||
selection: configuration.ServerSelection{
|
||||
VPN: constants.Wireguard,
|
||||
},
|
||||
servers: []models.WindscribeServer{
|
||||
{VPN: constants.OpenVPN, Hostname: "a"},
|
||||
{VPN: constants.Wireguard, Hostname: "b"},
|
||||
{VPN: constants.OpenVPN, Hostname: "c"},
|
||||
},
|
||||
filtered: []models.WindscribeServer{
|
||||
{VPN: constants.Wireguard, Hostname: "b"},
|
||||
},
|
||||
},
|
||||
"filter by region": {
|
||||
selection: configuration.ServerSelection{
|
||||
Regions: []string{"b"},
|
||||
},
|
||||
servers: []models.WindscribeServer{
|
||||
{Region: "a"},
|
||||
{Region: "b"},
|
||||
{Region: "c"},
|
||||
},
|
||||
filtered: []models.WindscribeServer{
|
||||
{Region: "b"},
|
||||
},
|
||||
},
|
||||
"filter by city": {
|
||||
selection: configuration.ServerSelection{
|
||||
Cities: []string{"b"},
|
||||
},
|
||||
servers: []models.WindscribeServer{
|
||||
{City: "a"},
|
||||
{City: "b"},
|
||||
{City: "c"},
|
||||
},
|
||||
filtered: []models.WindscribeServer{
|
||||
{City: "b"},
|
||||
},
|
||||
},
|
||||
"filter by hostname": {
|
||||
selection: configuration.ServerSelection{
|
||||
Hostnames: []string{"b"},
|
||||
},
|
||||
servers: []models.WindscribeServer{
|
||||
{Hostname: "a"},
|
||||
{Hostname: "b"},
|
||||
{Hostname: "c"},
|
||||
},
|
||||
filtered: []models.WindscribeServer{
|
||||
{Hostname: "b"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
randSource := rand.NewSource(0)
|
||||
|
||||
m := New(testCase.servers, randSource)
|
||||
|
||||
servers, err := m.filterServers(testCase.selection)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.filtered, servers)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -22,10 +22,12 @@ type serverData struct {
|
||||
Provider string `json:"provider"`
|
||||
IPv4 string `json:"ipv4_addr_in"`
|
||||
IPv6 string `json:"ipv6_addr_in"`
|
||||
Type string `json:"type"`
|
||||
PubKey string `json:"pubkey"` // Wireguard public key
|
||||
}
|
||||
|
||||
func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err error) {
|
||||
const url = "https://api.mullvad.net/www/relays/openvpn/"
|
||||
const url = "https://api.mullvad.net/www/relays/all/"
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,14 +6,17 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type hostToServer map[string]models.MullvadServer
|
||||
|
||||
var (
|
||||
ErrParseIPv4 = errors.New("cannot parse IPv4 address")
|
||||
ErrParseIPv6 = errors.New("cannot parse IPv6 address")
|
||||
ErrNoIP = errors.New("no IP address for VPN server")
|
||||
ErrParseIPv4 = errors.New("cannot parse IPv4 address")
|
||||
ErrParseIPv6 = errors.New("cannot parse IPv6 address")
|
||||
ErrVPNTypeNotSupported = errors.New("VPN type not supported")
|
||||
)
|
||||
|
||||
func (hts hostToServer) add(data serverData) (err error) {
|
||||
@@ -21,14 +24,8 @@ func (hts hostToServer) add(data serverData) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
ipv4 := net.ParseIP(data.IPv4)
|
||||
if ipv4 == nil || ipv4.To4() == nil {
|
||||
return fmt.Errorf("%w: %s", ErrParseIPv4, data.IPv4)
|
||||
}
|
||||
|
||||
ipv6 := net.ParseIP(data.IPv6)
|
||||
if ipv6 == nil || ipv6.To4() != nil {
|
||||
return fmt.Errorf("%w: %s", ErrParseIPv6, data.IPv6)
|
||||
if data.IPv4 == "" && data.IPv6 == "" {
|
||||
return ErrNoIP
|
||||
}
|
||||
|
||||
server, ok := hts[data.Hostname]
|
||||
@@ -36,13 +33,40 @@ func (hts hostToServer) add(data serverData) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch data.Type {
|
||||
case "openvpn":
|
||||
server.VPN = constants.OpenVPN
|
||||
case "wireguard":
|
||||
server.VPN = constants.Wireguard
|
||||
case "bridge":
|
||||
// ignore bridge servers
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, data.Type)
|
||||
}
|
||||
|
||||
if data.IPv4 != "" {
|
||||
ipv4 := net.ParseIP(data.IPv4)
|
||||
if ipv4 == nil || ipv4.To4() == nil {
|
||||
return fmt.Errorf("%w: %s", ErrParseIPv4, data.IPv4)
|
||||
}
|
||||
server.IPs = []net.IP{ipv4}
|
||||
}
|
||||
|
||||
if data.IPv6 != "" {
|
||||
ipv6 := net.ParseIP(data.IPv6)
|
||||
if ipv6 == nil || ipv6.To4() != nil {
|
||||
return fmt.Errorf("%w: %s", ErrParseIPv6, data.IPv6)
|
||||
}
|
||||
server.IPsV6 = []net.IP{ipv6}
|
||||
}
|
||||
|
||||
server.Country = data.Country
|
||||
server.City = strings.ReplaceAll(data.City, ",", "")
|
||||
server.Hostname = data.Hostname
|
||||
server.ISP = data.Provider
|
||||
server.Owned = data.Owned
|
||||
server.IPs = []net.IP{ipv4}
|
||||
server.IPsV6 = []net.IP{ipv6}
|
||||
server.WgPubKey = data.PubKey
|
||||
|
||||
hts[data.Hostname] = server
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ type groupData struct {
|
||||
City string `json:"city"`
|
||||
Nodes []serverData `json:"nodes"`
|
||||
OvpnX509 string `json:"ovpn_x509"`
|
||||
WgPubKey string `json:"wg_pubkey"`
|
||||
}
|
||||
|
||||
type serverData struct {
|
||||
|
||||
@@ -9,10 +9,14 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var ErrNotEnoughServers = errors.New("not enough servers found")
|
||||
var (
|
||||
ErrNotEnoughServers = errors.New("not enough servers found")
|
||||
ErrNoWireguardKey = errors.New("no wireguard public key found")
|
||||
)
|
||||
|
||||
func GetServers(ctx context.Context, client *http.Client, minServers int) (
|
||||
servers []models.WindscribeServer, err error) {
|
||||
@@ -26,19 +30,17 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
|
||||
for _, group := range regionData.Groups {
|
||||
city := group.City
|
||||
x5090Name := group.OvpnX509
|
||||
wgPubKey := group.WgPubKey
|
||||
for _, node := range group.Nodes {
|
||||
const maxIPsPerNode = 3
|
||||
ips := make([]net.IP, 0, maxIPsPerNode)
|
||||
ips := make([]net.IP, 0, 2) // nolint:gomnd
|
||||
if node.IP != nil {
|
||||
ips = append(ips, node.IP)
|
||||
}
|
||||
if node.IP2 != nil {
|
||||
ips = append(ips, node.IP2)
|
||||
}
|
||||
// if node.IP3 != nil { // Wireguard + Stealth
|
||||
// ips = append(ips, node.IP3)
|
||||
// }
|
||||
server := models.WindscribeServer{
|
||||
VPN: constants.OpenVPN,
|
||||
Region: region,
|
||||
City: city,
|
||||
Hostname: node.Hostname,
|
||||
@@ -46,6 +48,18 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
|
||||
IPs: ips,
|
||||
}
|
||||
servers = append(servers, server)
|
||||
|
||||
if node.IP3 == nil { // Wireguard + Stealth
|
||||
continue
|
||||
} else if wgPubKey == "" {
|
||||
return nil, fmt.Errorf("%w: for node %s", ErrNoWireguardKey, node.Hostname)
|
||||
}
|
||||
|
||||
server.VPN = constants.Wireguard
|
||||
server.OvpnX509 = ""
|
||||
server.WgPubKey = wgPubKey
|
||||
server.IPs = []net.IP{node.IP3}
|
||||
servers = append(servers, server)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,9 @@ func sortServers(servers []models.WindscribeServer) {
|
||||
sort.Slice(servers, func(i, j int) bool {
|
||||
if servers[i].Region == servers[j].Region {
|
||||
if servers[i].City == servers[j].City {
|
||||
if servers[i].Hostname == servers[j].Hostname {
|
||||
return servers[i].VPN < servers[j].VPN
|
||||
}
|
||||
return servers[i].Hostname < servers[j].Hostname
|
||||
}
|
||||
return servers[i].City < servers[j].City
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/publicip"
|
||||
@@ -37,6 +38,7 @@ type Loop struct {
|
||||
versionInfo bool
|
||||
// Configurators
|
||||
openvpnConf openvpn.Interface
|
||||
netLinker netlink.NetLinker
|
||||
fw firewallConfigurer
|
||||
routing routing.VPNGetter
|
||||
portForward portforward.StartStopper
|
||||
@@ -67,7 +69,7 @@ const (
|
||||
|
||||
func NewLoop(vpnSettings configuration.VPN,
|
||||
allServers models.AllServers, openvpnConf openvpn.Interface,
|
||||
fw firewallConfigurer, routing routing.VPNGetter,
|
||||
netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter,
|
||||
portForward portforward.StartStopper, starter command.Starter,
|
||||
publicip publicip.Looper, dnsLooper dns.Looper,
|
||||
logger logging.Logger, client *http.Client,
|
||||
@@ -86,6 +88,7 @@ func NewLoop(vpnSettings configuration.VPN,
|
||||
buildInfo: buildInfo,
|
||||
versionInfo: versionInfo,
|
||||
openvpnConf: openvpnConf,
|
||||
netLinker: netLinker,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
portForward: portForward,
|
||||
|
||||
@@ -30,8 +30,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
|
||||
providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
|
||||
|
||||
vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.starter, l.logger)
|
||||
var vpnRunner vpnRunner
|
||||
var serverName, vpnInterface string
|
||||
var err error
|
||||
if settings.Type == constants.OpenVPN {
|
||||
vpnInterface = settings.OpenVPN.Interface
|
||||
vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.starter, l.logger)
|
||||
} else { // Wireguard
|
||||
vpnInterface = settings.Wireguard.Interface
|
||||
vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.logger)
|
||||
}
|
||||
if err != nil {
|
||||
l.crashed(ctx, err)
|
||||
continue
|
||||
@@ -40,7 +49,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
portForwarding: settings.Provider.PortForwarding.Enabled,
|
||||
serverName: serverName,
|
||||
portForwarder: providerConf,
|
||||
vpnIntf: settings.OpenVPN.Interface,
|
||||
vpnIntf: vpnInterface,
|
||||
}
|
||||
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
|
||||
|
||||
@@ -17,13 +17,6 @@ type tunnelUpData struct {
|
||||
}
|
||||
|
||||
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
vpnDestination, err := l.routing.VPNDestinationIP()
|
||||
if err != nil {
|
||||
l.logger.Warn(err.Error())
|
||||
} else {
|
||||
l.logger.Info("VPN routing IP address: " + vpnDestination.String())
|
||||
}
|
||||
|
||||
if l.dnsLooper.GetSettings().Enabled {
|
||||
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
@@ -40,7 +33,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
err = l.startPortForwarding(ctx, data)
|
||||
err := l.startPortForwarding(ctx, data)
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
45
internal/vpn/wireguard.go
Normal file
45
internal/vpn/wireguard.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
)
|
||||
|
||||
var (
|
||||
errGetServer = errors.New("failed finding a VPN server")
|
||||
errCreateWireguard = errors.New("failed creating Wireguard")
|
||||
)
|
||||
|
||||
// setupWireguard sets Wireguard up using the configurators and settings given.
|
||||
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
||||
func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
|
||||
fw firewall.VPNConnectionSetter, providerConf provider.Provider,
|
||||
settings configuration.VPN, logger wireguard.Logger) (
|
||||
wireguarder wireguard.Wireguarder, serverName string, err error) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("%w: %s", errGetServer, err)
|
||||
}
|
||||
|
||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard)
|
||||
|
||||
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("%w: %s", errCreateWireguard, err)
|
||||
}
|
||||
|
||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
|
||||
}
|
||||
|
||||
return wireguarder, connection.Hostname, nil
|
||||
}
|
||||
25
internal/wireguard/address.go
Normal file
25
internal/wireguard/address.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addAddresses(link netlink.Link,
|
||||
addresses []*net.IPNet) (err error) {
|
||||
for _, ipNet := range addresses {
|
||||
address := &netlink.Addr{
|
||||
IPNet: ipNet,
|
||||
}
|
||||
|
||||
err = w.netlink.AddrAdd(link, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding address %s to link %s",
|
||||
err, address, link.Attrs().Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
94
internal/wireguard/address_test.go
Normal file
94
internal/wireguard/address_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
|
||||
ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)}
|
||||
|
||||
newLink := func() netlink.Link {
|
||||
linkAttrs := netlink.NewLinkAttrs()
|
||||
linkAttrs.Name = "a_bridge"
|
||||
return &netlink.Bridge{
|
||||
LinkAttrs: linkAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
addrs []*net.IPNet
|
||||
expectedAddrs []*netlink.Addr
|
||||
addrAddErrs []error
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: newLink(),
|
||||
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
||||
expectedAddrs: []*netlink.Addr{
|
||||
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
|
||||
},
|
||||
addrAddErrs: []error{nil, nil},
|
||||
},
|
||||
"first add error": {
|
||||
link: newLink(),
|
||||
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
||||
expectedAddrs: []*netlink.Addr{
|
||||
{IPNet: ipNetOne},
|
||||
},
|
||||
addrAddErrs: []error{errDummy},
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
||||
},
|
||||
"second add error": {
|
||||
link: newLink(),
|
||||
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
||||
expectedAddrs: []*netlink.Addr{
|
||||
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
|
||||
},
|
||||
addrAddErrs: []error{nil, errDummy},
|
||||
err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs))
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
for i := range testCase.expectedAddrs {
|
||||
netLinker.EXPECT().
|
||||
AddrAdd(testCase.link, testCase.expectedAddrs[i]).
|
||||
Return(testCase.addrAddErrs[i])
|
||||
}
|
||||
|
||||
err := wg.addAddresses(testCase.link, testCase.addrs)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
59
internal/wireguard/cleanup.go
Normal file
59
internal/wireguard/cleanup.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package wireguard
|
||||
|
||||
import "sort"
|
||||
|
||||
type closer struct {
|
||||
operation string
|
||||
step step
|
||||
close func() error
|
||||
closed bool
|
||||
}
|
||||
|
||||
type closers []closer
|
||||
|
||||
func (c *closers) add(operation string, step step,
|
||||
closeFunc func() error) {
|
||||
closer := closer{
|
||||
operation: operation,
|
||||
step: step,
|
||||
close: closeFunc,
|
||||
}
|
||||
*c = append(*c, closer)
|
||||
}
|
||||
|
||||
func (c *closers) cleanup(logger Logger) {
|
||||
closers := *c
|
||||
|
||||
sort.Slice(closers, func(i, j int) bool {
|
||||
return closers[i].step < closers[j].step
|
||||
})
|
||||
|
||||
for i, closer := range closers {
|
||||
if closer.closed {
|
||||
continue
|
||||
} else {
|
||||
closers[i].closed = true
|
||||
}
|
||||
logger.Debug(closer.operation + "...")
|
||||
err := closer.close()
|
||||
if err != nil {
|
||||
logger.Error("failed " + closer.operation + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type step int
|
||||
|
||||
const (
|
||||
// stepOne closes the wireguard controller client,
|
||||
// and removes the IP rule.
|
||||
stepOne step = iota
|
||||
// stepTwo closes the UAPI listener.
|
||||
stepTwo
|
||||
// stepThree closes the UAPI file.
|
||||
stepThree
|
||||
// stepFour closes the Wireguard device.
|
||||
stepFour
|
||||
// stepFive closes the bind connection and the TUN device file.
|
||||
stepFive
|
||||
)
|
||||
57
internal/wireguard/cleanup_test.go
Normal file
57
internal/wireguard/cleanup_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_closers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
||||
var (
|
||||
AErr error
|
||||
BErr = errors.New("B failed")
|
||||
CErr = errors.New("C failed")
|
||||
)
|
||||
|
||||
var closers closers
|
||||
closers.add("closing A", stepFive, func() error {
|
||||
ACloseCalled = true
|
||||
return AErr
|
||||
})
|
||||
|
||||
closers.add("closing B", stepThree, func() error {
|
||||
BCloseCalled = true
|
||||
return BErr
|
||||
})
|
||||
|
||||
closers.add("closing C", stepTwo, func() error {
|
||||
CCloseCalled = true
|
||||
return CErr
|
||||
})
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
prevCall := logger.EXPECT().Debug("closing C...")
|
||||
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
|
||||
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
|
||||
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
|
||||
logger.EXPECT().Debug("closing A...").After(prevCall)
|
||||
|
||||
closers.cleanup(logger)
|
||||
|
||||
closers.cleanup(logger) // run twice should not close already closed
|
||||
|
||||
for _, closer := range closers {
|
||||
assert.True(t, closer.closed)
|
||||
}
|
||||
|
||||
assert.True(t, ACloseCalled)
|
||||
assert.True(t, BCloseCalled)
|
||||
assert.True(t, CCloseCalled)
|
||||
}
|
||||
86
internal/wireguard/config.go
Normal file
86
internal/wireguard/config.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var (
|
||||
errMakeConfig = errors.New("cannot make device configuration")
|
||||
errConfigureDevice = errors.New("cannot configure device")
|
||||
)
|
||||
|
||||
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||
deviceConfig, err := makeDeviceConfig(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errMakeConfig, err)
|
||||
}
|
||||
|
||||
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errConfigureDevice, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
||||
privateKey, err := wgtypes.ParseKey(settings.PrivateKey)
|
||||
if err != nil {
|
||||
return config, ErrPrivateKeyInvalid
|
||||
}
|
||||
|
||||
publicKey, err := wgtypes.ParseKey(settings.PublicKey)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("%w: %s", ErrPublicKeyInvalid, settings.PublicKey)
|
||||
}
|
||||
|
||||
var preSharedKey *wgtypes.Key
|
||||
if settings.PreSharedKey != "" {
|
||||
preSharedKeyValue, err := wgtypes.ParseKey(settings.PreSharedKey)
|
||||
if err != nil {
|
||||
return config, ErrPreSharedKeyInvalid
|
||||
}
|
||||
preSharedKey = &preSharedKeyValue
|
||||
}
|
||||
|
||||
firewallMark := settings.FirewallMark
|
||||
|
||||
config = wgtypes.Config{
|
||||
PrivateKey: &privateKey,
|
||||
ReplacePeers: true,
|
||||
FirewallMark: &firewallMark,
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: publicKey,
|
||||
PresharedKey: preSharedKey,
|
||||
AllowedIPs: []net.IPNet{
|
||||
*allIPv4(),
|
||||
*allIPv6(),
|
||||
},
|
||||
ReplaceAllowedIPs: true,
|
||||
Endpoint: settings.Endpoint,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func allIPv4() (ipNet *net.IPNet) {
|
||||
return &net.IPNet{
|
||||
IP: net.IPv4(0, 0, 0, 0),
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
}
|
||||
}
|
||||
|
||||
func allIPv6() (ipNet *net.IPNet) {
|
||||
return &net.IPNet{
|
||||
IP: net.IPv6zero,
|
||||
Mask: []byte(net.IPv6zero),
|
||||
}
|
||||
}
|
||||
126
internal/wireguard/config_test.go
Normal file
126
internal/wireguard/config_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func Test_makeDeviceConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
|
||||
validKey3 = "gFIW0lTmBYEucynoIg+XmeWckDUXTcC4Po5ijR5G+HM="
|
||||
)
|
||||
|
||||
parseKey := func(t *testing.T, s string) *wgtypes.Key {
|
||||
t.Helper()
|
||||
key, err := wgtypes.ParseKey(s)
|
||||
require.NoError(t, err)
|
||||
return &key
|
||||
}
|
||||
|
||||
intPtr := func(n int) *int { return &n }
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
config wgtypes.Config
|
||||
err error
|
||||
}{
|
||||
"bad private key": {
|
||||
settings: Settings{
|
||||
PrivateKey: "bad key",
|
||||
},
|
||||
err: ErrPrivateKeyInvalid,
|
||||
},
|
||||
"bad public key": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse public key: bad key"),
|
||||
},
|
||||
"bad pre-shared key": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
PreSharedKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse pre-shared key"),
|
||||
},
|
||||
"valid settings": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
PreSharedKey: validKey3,
|
||||
FirewallMark: 9876,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(99, 99, 99, 99),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
config: wgtypes.Config{
|
||||
PrivateKey: parseKey(t, validKey1),
|
||||
ReplacePeers: true,
|
||||
FirewallMark: intPtr(9876),
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: *parseKey(t, validKey2),
|
||||
PresharedKey: parseKey(t, validKey3),
|
||||
AllowedIPs: []net.IPNet{
|
||||
{
|
||||
IP: net.IPv4(0, 0, 0, 0),
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
},
|
||||
{
|
||||
IP: net.IPv6zero,
|
||||
Mask: []byte(net.IPv6zero),
|
||||
},
|
||||
},
|
||||
ReplaceAllowedIPs: true,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(99, 99, 99, 99),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := makeDeviceConfig(testCase.settings)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.config, config)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_allIPv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
ipNet := allIPv4()
|
||||
assert.Equal(t, "0.0.0.0/0", ipNet.String())
|
||||
}
|
||||
|
||||
func Test_allIPv6(t *testing.T) {
|
||||
t.Parallel()
|
||||
ipNet := allIPv6()
|
||||
assert.Equal(t, "::/0", ipNet.String())
|
||||
}
|
||||
30
internal/wireguard/constructor.go
Normal file
30
internal/wireguard/constructor.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package wireguard
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/netlink"
|
||||
|
||||
var _ Wireguarder = (*Wireguard)(nil)
|
||||
|
||||
type Wireguarder interface {
|
||||
Runner
|
||||
Runner
|
||||
}
|
||||
|
||||
type Wireguard struct {
|
||||
logger Logger
|
||||
settings Settings
|
||||
netlink netlink.NetLinker
|
||||
}
|
||||
|
||||
func New(settings Settings, netlink NetLinker,
|
||||
logger Logger) (w *Wireguard, err error) {
|
||||
settings.SetDefaults()
|
||||
if err := settings.Check(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Wireguard{
|
||||
logger: logger,
|
||||
settings: settings,
|
||||
netlink: netlink,
|
||||
}, nil
|
||||
}
|
||||
80
internal/wireguard/constructor_test.go
Normal file
80
internal/wireguard/constructor_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
logger := NewMockLogger(nil)
|
||||
netLinker := NewMockNetLinker(nil)
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
wireguard *Wireguard
|
||||
err error
|
||||
}{
|
||||
"bad settings": {
|
||||
settings: Settings{
|
||||
PrivateKey: "",
|
||||
},
|
||||
err: ErrPrivateKeyMissing,
|
||||
},
|
||||
"minimal valid settings": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKeyString,
|
||||
PublicKey: validKeyString,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
},
|
||||
Addresses: []*net.IPNet{{
|
||||
IP: net.IPv4(5, 6, 7, 8),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
},
|
||||
FirewallMark: 100,
|
||||
},
|
||||
wireguard: &Wireguard{
|
||||
logger: logger,
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKeyString,
|
||||
PublicKey: validKeyString,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{
|
||||
IP: net.IPv4(5, 6, 7, 8),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
},
|
||||
FirewallMark: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguard, err := New(testCase.settings, netLinker, logger)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.wireguard, wireguard)
|
||||
})
|
||||
}
|
||||
}
|
||||
26
internal/wireguard/log.go
Normal file
26
internal/wireguard/log.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
|
||||
return &device.Logger{
|
||||
Verbosef: func(format string, args ...interface{}) {
|
||||
logger.Debug(fmt.Sprintf(format, args...))
|
||||
},
|
||||
Errorf: func(format string, args ...interface{}) {
|
||||
logger.Error(fmt.Sprintf(format, args...))
|
||||
},
|
||||
}
|
||||
}
|
||||
70
internal/wireguard/log_mock_test.go
Normal file
70
internal/wireguard/log_mock_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: Logger)
|
||||
|
||||
// Package wireguard is a generated GoMock package.
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
23
internal/wireguard/log_test.go
Normal file
23
internal/wireguard/log_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
func Test_makeDeviceLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
|
||||
deviceLogger := makeDeviceLogger(logger)
|
||||
|
||||
logger.EXPECT().Debug("test 1")
|
||||
deviceLogger.Verbosef("test %d", 1)
|
||||
|
||||
logger.EXPECT().Error("test 2")
|
||||
deviceLogger.Errorf("test %d", 2)
|
||||
}
|
||||
113
internal/wireguard/netlink_integration_test.go
Normal file
113
internal/wireguard/netlink_integration_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// +build netlink
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
inetlink "github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := inetlink.New()
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
}
|
||||
|
||||
intfName := "test_" + fmt.Sprint(rand.Intn(10000)) //nolint:gosec
|
||||
|
||||
// Add link
|
||||
linkAttrs := netlink.NewLinkAttrs()
|
||||
linkAttrs.Name = intfName
|
||||
link := &netlink.Bridge{
|
||||
LinkAttrs: linkAttrs,
|
||||
}
|
||||
err := netlink.LinkAdd(link)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err = netlink.LinkDel(link)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
addresses := []*net.IPNet{
|
||||
{IP: net.IP{1, 2, 3, 4}, Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
{IP: net.IP{5, 6, 7, 8}, Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
}
|
||||
|
||||
// Success
|
||||
err = wg.addAddresses(link, addresses)
|
||||
require.NoError(t, err)
|
||||
|
||||
netlinkAddresses, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(addresses), len(netlinkAddresses))
|
||||
for i, netlinkAddress := range netlinkAddresses {
|
||||
ipNet := netlinkAddress.IPNet
|
||||
assert.Equal(t, addresses[i], ipNet)
|
||||
}
|
||||
|
||||
// Existing address cannot be added
|
||||
err = wg.addAddresses(link, addresses)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "file exists: when adding address 1.2.3.4/32 to link test_8081", err.Error())
|
||||
}
|
||||
|
||||
func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := inetlink.New()
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
}
|
||||
|
||||
rulePriority := 10000
|
||||
const firewallMark = 999
|
||||
|
||||
cleanup, err := wg.addRule(rulePriority, firewallMark)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := cleanup()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
|
||||
require.NoError(t, err)
|
||||
var rule netlink.Rule
|
||||
var ruleFound bool
|
||||
for _, rule = range rules {
|
||||
if rule.Mark == firewallMark {
|
||||
ruleFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, ruleFound)
|
||||
expectedRule := netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: 4294967295,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
}
|
||||
assert.Equal(t, expectedRule, rule)
|
||||
|
||||
// Existing rule cannot be added
|
||||
nilCleanup, err := wg.addRule(rulePriority, firewallMark)
|
||||
if nilCleanup != nil {
|
||||
_ = nilCleanup() // in case it succeeds
|
||||
}
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "file exists: when adding rule: ip rule 10000: from <nil> table 999", err.Error())
|
||||
}
|
||||
12
internal/wireguard/netlinker.go
Normal file
12
internal/wireguard/netlinker.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package wireguard
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
||||
|
||||
type NetLinker interface {
|
||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||
RouteAdd(route *netlink.Route) error
|
||||
RuleAdd(rule *netlink.Rule) error
|
||||
RuleDel(rule *netlink.Rule) error
|
||||
}
|
||||
91
internal/wireguard/netlinker_mock_test.go
Normal file
91
internal/wireguard/netlinker_mock_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: NetLinker)
|
||||
|
||||
// Package wireguard is a generated GoMock package.
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
netlink "github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// MockNetLinker is a mock of NetLinker interface.
|
||||
type MockNetLinker struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockNetLinkerMockRecorder
|
||||
}
|
||||
|
||||
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
|
||||
type MockNetLinkerMockRecorder struct {
|
||||
mock *MockNetLinker
|
||||
}
|
||||
|
||||
// NewMockNetLinker creates a new mock instance.
|
||||
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
|
||||
mock := &MockNetLinker{ctrl: ctrl}
|
||||
mock.recorder = &MockNetLinkerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddrAdd mocks base method.
|
||||
func (m *MockNetLinker) AddrAdd(arg0 netlink.Link, arg1 *netlink.Addr) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrAdd", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddrAdd indicates an expected call of AddrAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1)
|
||||
}
|
||||
|
||||
// RouteAdd mocks base method.
|
||||
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RouteAdd indicates an expected call of RouteAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
|
||||
}
|
||||
|
||||
// RuleAdd mocks base method.
|
||||
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RuleAdd indicates an expected call of RuleAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
|
||||
}
|
||||
|
||||
// RuleDel mocks base method.
|
||||
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RuleDel indicates an expected call of RuleDel.
|
||||
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
|
||||
}
|
||||
26
internal/wireguard/route.go
Normal file
26
internal/wireguard/route.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// TODO add IPv6 route if IPv6 is supported
|
||||
|
||||
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
|
||||
firewallMark int) (err error) {
|
||||
route := &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: dst,
|
||||
Table: firewallMark,
|
||||
}
|
||||
|
||||
err = w.netlink.RouteAdd(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding route: %s", err, route)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
85
internal/wireguard/route_test.go
Normal file
85
internal/wireguard/route_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const linkIndex = 88
|
||||
newLink := func() netlink.Link {
|
||||
linkAttrs := netlink.NewLinkAttrs()
|
||||
linkAttrs.Name = "a_bridge"
|
||||
linkAttrs.Index = linkIndex
|
||||
return &netlink.Bridge{
|
||||
LinkAttrs: linkAttrs,
|
||||
}
|
||||
}
|
||||
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
|
||||
const firewallMark = 51820
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
dst *net.IPNet
|
||||
expectedRoute *netlink.Route
|
||||
routeAddErr error
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: newLink(),
|
||||
dst: ipNet,
|
||||
expectedRoute: &netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipNet,
|
||||
Table: firewallMark,
|
||||
},
|
||||
},
|
||||
"route add error": {
|
||||
link: newLink(),
|
||||
dst: ipNet,
|
||||
expectedRoute: &netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipNet,
|
||||
Table: firewallMark,
|
||||
},
|
||||
routeAddErr: errDummy,
|
||||
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820}"), //nolint:lll
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().
|
||||
RouteAdd(testCase.expectedRoute).
|
||||
Return(testCase.routeAddErr)
|
||||
|
||||
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
28
internal/wireguard/rule.go
Normal file
28
internal/wireguard/rule.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRule(rulePriority, firewallMark int) (
|
||||
cleanup func() error, err error) {
|
||||
rule := netlink.NewRule()
|
||||
rule.Invert = true
|
||||
rule.Priority = rulePriority
|
||||
rule.Mark = firewallMark
|
||||
rule.Table = firewallMark
|
||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||
return nil, fmt.Errorf("%w: when adding rule: %s", err, rule)
|
||||
}
|
||||
|
||||
cleanup = func() error {
|
||||
err := w.netlink.RuleDel(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when deleting rule: %s", err, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return cleanup, nil
|
||||
}
|
||||
106
internal/wireguard/rule_test.go
Normal file
106
internal/wireguard/rule_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rulePriority = 987
|
||||
const firewallMark = 456
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
expectedRule *netlink.Rule
|
||||
ruleAddErr error
|
||||
err error
|
||||
ruleDelErr error
|
||||
cleanupErr error
|
||||
}{
|
||||
"success": {
|
||||
expectedRule: &netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
},
|
||||
},
|
||||
"rule add error": {
|
||||
expectedRule: &netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
},
|
||||
ruleAddErr: errDummy,
|
||||
err: errors.New("dummy: when adding rule: ip rule 987: from <nil> table 456"),
|
||||
},
|
||||
"rule delete error": {
|
||||
expectedRule: &netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
},
|
||||
ruleDelErr: errDummy,
|
||||
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from <nil> table 456"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
|
||||
Return(testCase.ruleAddErr)
|
||||
cleanup, err := wg.addRule(rulePriority, firewallMark)
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
netLinker.EXPECT().RuleDel(testCase.expectedRule).
|
||||
Return(testCase.ruleDelErr)
|
||||
err = cleanup()
|
||||
if testCase.cleanupErr != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.cleanupErr.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
165
internal/wireguard/run.go
Normal file
165
internal/wireguard/run.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCreateTun = errors.New("cannot create TUN device")
|
||||
ErrFindLink = errors.New("cannot find link")
|
||||
ErrFindDevice = errors.New("cannot find Wireguard device")
|
||||
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
|
||||
ErrWgctrlOpen = errors.New("cannot open wgctrl")
|
||||
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
|
||||
ErrAddAddress = errors.New("cannot add address to wireguard interface")
|
||||
ErrConfigure = errors.New("cannot configure wireguard interface")
|
||||
ErrIfaceUp = errors.New("cannot set the interface to UP")
|
||||
ErrRouteAdd = errors.New("cannot add route for interface")
|
||||
ErrRuleAdd = errors.New("cannot add rule for interface")
|
||||
ErrDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
Run(ctx context.Context, waitError chan<- error, ready chan<- struct{})
|
||||
}
|
||||
|
||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
|
||||
return
|
||||
}
|
||||
|
||||
var closers closers
|
||||
closers.add("closing controller client", stepOne, client.Close)
|
||||
|
||||
defer closers.cleanup(w.logger)
|
||||
|
||||
tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("closing TUN device", stepFive, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
return
|
||||
} else if tunName != w.settings.InterfaceName {
|
||||
waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
ErrCreateTun, w.settings.InterfaceName, tunName)
|
||||
return
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(w.settings.InterfaceName)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err)
|
||||
return
|
||||
}
|
||||
|
||||
bind := conn.NewDefaultBind()
|
||||
|
||||
closers.add("closing bind", stepFive, bind.Close)
|
||||
|
||||
deviceLogger := makeDeviceLogger(w.logger)
|
||||
device := device.NewDevice(tun, bind, deviceLogger)
|
||||
|
||||
closers.add("closing Wireguard device", stepFour, func() error {
|
||||
device.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||
|
||||
uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||
|
||||
// acceptAndHandle exits when uapiListener is closed
|
||||
uapiAcceptErrorCh := make(chan error)
|
||||
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
|
||||
|
||||
err = w.addAddresses(link, w.settings.Addresses)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = configureDevice(client, w.settings)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||
return
|
||||
}
|
||||
|
||||
ruleCleanup, err := w.addRule(
|
||||
w.settings.RulePriority, w.settings.FirewallMark)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRuleAdd, err)
|
||||
return
|
||||
}
|
||||
closers.add("removing rule", stepOne, ruleCleanup)
|
||||
|
||||
w.logger.Info("Wireguard is up")
|
||||
ready <- struct{}{}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = ErrDeviceWaited
|
||||
}
|
||||
|
||||
closers.cleanup(w.logger)
|
||||
|
||||
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||
|
||||
waitError <- err
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||
uapiAcceptErrorCh chan<- error) {
|
||||
for { // stopped by uapiFile.Close()
|
||||
conn, err := uapi.Accept()
|
||||
if err != nil {
|
||||
uapiAcceptErrorCh <- err
|
||||
return
|
||||
}
|
||||
go device.IpcHandle(conn)
|
||||
}
|
||||
}
|
||||
212
internal/wireguard/settings.go
Normal file
212
internal/wireguard/settings.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
// Interface name for the Wireguard interface.
|
||||
// It defaults to wg0 if unset.
|
||||
InterfaceName string
|
||||
// Private key in base 64 format
|
||||
PrivateKey string
|
||||
// Public key in base 64 format
|
||||
PublicKey string
|
||||
// Pre shared key in base 64 format
|
||||
PreSharedKey string
|
||||
// Wireguard server endpoint to connect to.
|
||||
Endpoint *net.UDPAddr
|
||||
// Addresses assigned to the client.
|
||||
Addresses []*net.IPNet
|
||||
// FirewallMark to be used in routing tables and IP rules.
|
||||
// It defaults to 51820 if left to 0.
|
||||
FirewallMark int
|
||||
// RulePriority is the priority for the rule created with the
|
||||
// FirewallMark.
|
||||
RulePriority int
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
if s.InterfaceName == "" {
|
||||
const defaultInterfaceName = "wg0"
|
||||
s.InterfaceName = defaultInterfaceName
|
||||
}
|
||||
|
||||
if s.Endpoint != nil && s.Endpoint.Port == 0 {
|
||||
const defaultPort = 51820
|
||||
s.Endpoint.Port = defaultPort
|
||||
}
|
||||
|
||||
if s.FirewallMark == 0 {
|
||||
const defaultFirewallMark = 51820
|
||||
s.FirewallMark = defaultFirewallMark
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInterfaceNameInvalid = errors.New("invalid interface name")
|
||||
ErrPrivateKeyMissing = errors.New("private key is missing")
|
||||
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
|
||||
ErrPublicKeyMissing = errors.New("public key is missing")
|
||||
ErrPublicKeyInvalid = errors.New("cannot parse public key")
|
||||
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
|
||||
ErrEndpointMissing = errors.New("endpoint is missing")
|
||||
ErrEndpointIPMissing = errors.New("endpoint IP is missing")
|
||||
ErrEndpointPortMissing = errors.New("endpoint port is missing")
|
||||
ErrAddressMissing = errors.New("interface address is missing")
|
||||
ErrAddressNil = errors.New("interface address is nil")
|
||||
ErrAddressIPMissing = errors.New("interface address IP is missing")
|
||||
ErrAddressMaskMissing = errors.New("interface address mask is missing")
|
||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||
)
|
||||
|
||||
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
|
||||
func (s *Settings) Check() (err error) {
|
||||
if !interfaceNameRegexp.MatchString(s.InterfaceName) {
|
||||
return fmt.Errorf("%w: %s", ErrInterfaceNameInvalid, s.InterfaceName)
|
||||
}
|
||||
|
||||
if s.PrivateKey == "" {
|
||||
return ErrPrivateKeyMissing
|
||||
} else if _, err := wgtypes.ParseKey(s.PrivateKey); err != nil {
|
||||
return ErrPrivateKeyInvalid
|
||||
}
|
||||
|
||||
if s.PublicKey == "" {
|
||||
return ErrPublicKeyMissing
|
||||
} else if _, err := wgtypes.ParseKey(s.PublicKey); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrPublicKeyInvalid, s.PublicKey)
|
||||
}
|
||||
|
||||
if s.PreSharedKey != "" {
|
||||
if _, err := wgtypes.ParseKey(s.PreSharedKey); err != nil {
|
||||
return ErrPreSharedKeyInvalid
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.Endpoint == nil:
|
||||
return ErrEndpointMissing
|
||||
case s.Endpoint.IP == nil:
|
||||
return ErrEndpointIPMissing
|
||||
case s.Endpoint.Port == 0:
|
||||
return ErrEndpointPortMissing
|
||||
}
|
||||
|
||||
if len(s.Addresses) == 0 {
|
||||
return ErrAddressMissing
|
||||
}
|
||||
for i, addr := range s.Addresses {
|
||||
switch {
|
||||
case addr == nil:
|
||||
return fmt.Errorf("%w: for address %d of %d",
|
||||
ErrAddressNil, i+1, len(s.Addresses))
|
||||
case addr.IP == nil:
|
||||
return fmt.Errorf("%w: for address %d of %d",
|
||||
ErrAddressIPMissing, i+1, len(s.Addresses))
|
||||
case addr.Mask == nil:
|
||||
return fmt.Errorf("%w: for address %d of %d",
|
||||
ErrAddressMaskMissing, i+1, len(s.Addresses))
|
||||
}
|
||||
}
|
||||
|
||||
if s.FirewallMark == 0 {
|
||||
return ErrFirewallMarkMissing
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Settings) String() string {
|
||||
lines := s.ToLines(ToLinesSettings{})
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
type ToLinesSettings struct {
|
||||
// Indent defaults to 4 spaces " ".
|
||||
Indent *string
|
||||
// FieldPrefix defaults to "├── ".
|
||||
FieldPrefix *string
|
||||
// LastFieldPrefix defaults to "└── ".
|
||||
LastFieldPrefix *string
|
||||
}
|
||||
|
||||
func (settings *ToLinesSettings) setDefaults() {
|
||||
toStringPtr := func(s string) *string { return &s }
|
||||
if settings.Indent == nil {
|
||||
settings.Indent = toStringPtr(" ")
|
||||
}
|
||||
if settings.FieldPrefix == nil {
|
||||
settings.FieldPrefix = toStringPtr("├── ")
|
||||
}
|
||||
if settings.LastFieldPrefix == nil {
|
||||
settings.LastFieldPrefix = toStringPtr("└── ")
|
||||
}
|
||||
}
|
||||
|
||||
// ToLines serializes the settings to a slice of strings for display.
|
||||
func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
|
||||
settings.setDefaults()
|
||||
|
||||
indent := *settings.Indent
|
||||
fieldPrefix := *settings.FieldPrefix
|
||||
lastFieldPrefix := *settings.LastFieldPrefix
|
||||
|
||||
lines = append(lines, fieldPrefix+"Interface name: "+s.InterfaceName)
|
||||
const (
|
||||
set = "set"
|
||||
notSet = "not set"
|
||||
)
|
||||
|
||||
isSet := notSet
|
||||
if s.PrivateKey != "" {
|
||||
isSet = set
|
||||
}
|
||||
lines = append(lines, fieldPrefix+"Private key: "+isSet)
|
||||
|
||||
if s.PublicKey != "" {
|
||||
lines = append(lines, fieldPrefix+"PublicKey: "+s.PublicKey)
|
||||
}
|
||||
|
||||
isSet = notSet
|
||||
if s.PreSharedKey != "" {
|
||||
isSet = set
|
||||
}
|
||||
lines = append(lines, fieldPrefix+"Pre shared key: "+isSet)
|
||||
|
||||
endpointStr := notSet
|
||||
if s.Endpoint != nil {
|
||||
endpointStr = s.Endpoint.String()
|
||||
}
|
||||
lines = append(lines, fieldPrefix+"Endpoint: "+endpointStr)
|
||||
|
||||
if s.FirewallMark != 0 {
|
||||
lines = append(lines, fieldPrefix+"Firewall mark: "+fmt.Sprint(s.FirewallMark))
|
||||
}
|
||||
|
||||
if s.RulePriority != 0 {
|
||||
lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority))
|
||||
}
|
||||
|
||||
if len(s.Addresses) == 0 {
|
||||
lines = append(lines, lastFieldPrefix+"Addresses: "+notSet)
|
||||
} else {
|
||||
lines = append(lines, lastFieldPrefix+"Addresses:")
|
||||
for i, address := range s.Addresses {
|
||||
prefix := fieldPrefix
|
||||
if i == len(s.Addresses)-1 {
|
||||
prefix = lastFieldPrefix
|
||||
}
|
||||
lines = append(lines, indent+prefix+address.String())
|
||||
}
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
377
internal/wireguard/settings_test.go
Normal file
377
internal/wireguard/settings_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Settings_SetDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
original Settings
|
||||
expected Settings
|
||||
}{
|
||||
"empty settings": {
|
||||
expected: Settings{
|
||||
InterfaceName: "wg0",
|
||||
FirewallMark: 51820,
|
||||
},
|
||||
},
|
||||
"default endpoint port": {
|
||||
original: Settings{
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
},
|
||||
},
|
||||
expected: Settings{
|
||||
InterfaceName: "wg0",
|
||||
FirewallMark: 51820,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
},
|
||||
"not empty settings": {
|
||||
original: Settings{
|
||||
InterfaceName: "wg1",
|
||||
FirewallMark: 999,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 9999,
|
||||
},
|
||||
},
|
||||
expected: Settings{
|
||||
InterfaceName: "wg1",
|
||||
FirewallMark: 999,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 9999,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCase.original.SetDefaults()
|
||||
|
||||
assert.Equal(t, testCase.expected, testCase.original)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Settings_Check(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
|
||||
)
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
err error
|
||||
}{
|
||||
"empty settings": {
|
||||
err: errors.New("invalid interface name: "),
|
||||
},
|
||||
"bad interface name": {
|
||||
settings: Settings{
|
||||
InterfaceName: "$H1T",
|
||||
},
|
||||
err: errors.New("invalid interface name: $H1T"),
|
||||
},
|
||||
"empty private key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
},
|
||||
err: ErrPrivateKeyMissing,
|
||||
},
|
||||
"bad private key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: "bad key",
|
||||
},
|
||||
err: ErrPrivateKeyInvalid,
|
||||
},
|
||||
"empty public key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
},
|
||||
err: ErrPublicKeyMissing,
|
||||
},
|
||||
"bad public key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse public key: bad key"),
|
||||
},
|
||||
"bad preshared key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
PreSharedKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse pre-shared key"),
|
||||
},
|
||||
"empty endpoint": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
},
|
||||
err: ErrEndpointMissing,
|
||||
},
|
||||
"nil endpoint IP": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{},
|
||||
},
|
||||
err: ErrEndpointIPMissing,
|
||||
},
|
||||
"nil endpoint port": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
},
|
||||
},
|
||||
err: ErrEndpointPortMissing,
|
||||
},
|
||||
"no address": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
err: ErrAddressMissing,
|
||||
},
|
||||
"nil address": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{nil},
|
||||
},
|
||||
err: errors.New("interface address is nil: for address 1 of 1"),
|
||||
},
|
||||
"nil address IP": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{}},
|
||||
},
|
||||
err: errors.New("interface address IP is missing: for address 1 of 1"),
|
||||
},
|
||||
"nil address mask": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4)}},
|
||||
},
|
||||
err: errors.New("interface address mask is missing: for address 1 of 1"),
|
||||
},
|
||||
"zero firewall mark": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
|
||||
},
|
||||
err: ErrFirewallMarkMissing,
|
||||
},
|
||||
"all valid": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
|
||||
FirewallMark: 999,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := testCase.settings.Check()
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toStringPtr(s string) *string { return &s }
|
||||
|
||||
func Test_ToLinesSettings_setDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
settings := ToLinesSettings{
|
||||
Indent: toStringPtr("indent"),
|
||||
}
|
||||
|
||||
someFunc := func(settings ToLinesSettings) {
|
||||
settings.setDefaults()
|
||||
expectedSettings := ToLinesSettings{
|
||||
Indent: toStringPtr("indent"),
|
||||
FieldPrefix: toStringPtr("├── "),
|
||||
LastFieldPrefix: toStringPtr("└── "),
|
||||
}
|
||||
assert.Equal(t, expectedSettings, settings)
|
||||
}
|
||||
someFunc(settings)
|
||||
|
||||
untouchedSettings := ToLinesSettings{
|
||||
Indent: toStringPtr("indent"),
|
||||
}
|
||||
assert.Equal(t, untouchedSettings, settings)
|
||||
}
|
||||
|
||||
func Test_Settings_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
settings := Settings{
|
||||
InterfaceName: "wg0",
|
||||
}
|
||||
const expected = `├── Interface name: wg0
|
||||
├── Private key: not set
|
||||
├── Pre shared key: not set
|
||||
├── Endpoint: not set
|
||||
└── Addresses: not set`
|
||||
s := settings.String()
|
||||
assert.Equal(t, expected, s)
|
||||
}
|
||||
|
||||
func Test_Settings_Lines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
lineSettings ToLinesSettings
|
||||
lines []string
|
||||
}{
|
||||
"empty settings": {
|
||||
lines: []string{
|
||||
"├── Interface name: ",
|
||||
"├── Private key: not set",
|
||||
"├── Pre shared key: not set",
|
||||
"├── Endpoint: not set",
|
||||
"└── Addresses: not set",
|
||||
},
|
||||
},
|
||||
"settings all set": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: "private key",
|
||||
PublicKey: "public key",
|
||||
PreSharedKey: "pre-shared key",
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
FirewallMark: 999,
|
||||
RulePriority: 888,
|
||||
Addresses: []*net.IPNet{
|
||||
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
lines: []string{
|
||||
"├── Interface name: wg0",
|
||||
"├── Private key: set",
|
||||
"├── PublicKey: public key",
|
||||
"├── Pre shared key: set",
|
||||
"├── Endpoint: 1.2.3.4:51820",
|
||||
"├── Firewall mark: 999",
|
||||
"├── Rule priority: 888",
|
||||
"└── Addresses:",
|
||||
" ├── 1.1.1.1/24",
|
||||
" └── 2.2.2.2/32",
|
||||
},
|
||||
},
|
||||
"custom line settings": {
|
||||
lineSettings: ToLinesSettings{
|
||||
Indent: toStringPtr(" "),
|
||||
FieldPrefix: toStringPtr("- "),
|
||||
LastFieldPrefix: toStringPtr("* "),
|
||||
},
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
Addresses: []*net.IPNet{
|
||||
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
lines: []string{
|
||||
"- Interface name: wg0",
|
||||
"- Private key: not set",
|
||||
"- Pre shared key: not set",
|
||||
"- Endpoint: not set",
|
||||
"* Addresses:",
|
||||
" - 1.1.1.1/24",
|
||||
" * 2.2.2.2/32",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lines := testCase.settings.ToLines(testCase.lineSettings)
|
||||
|
||||
assert.Equal(t, testCase.lines, lines)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
- Filter servers by protocol for all
|
||||
- Multiple IPs addresses support for all proviedrs
|
||||
- Use `internal/netlink` in firewall and routing packages
|
||||
|
||||
## Code
|
||||
|
||||
|
||||
Reference in New Issue
Block a user