Wireguard support for Mullvad and Windscribe (#565)

- `internal/wireguard` client package with unit tests
- Implementation works with kernel space or user space if unavailable
- `WIREGUARD_PRIVATE_KEY`
- `WIREGUARD_ADDRESS`
- `WIREGUARD_PRESHARED_KEY`
- `WIREGUARD_PORT`
- `internal/netlink` package used by `internal/wireguard`
This commit is contained in:
Quentin McGaw
2021-08-22 14:58:39 -07:00
committed by GitHub
parent 0bfd58a3f5
commit 614eb10d67
70 changed files with 13595 additions and 148 deletions

View File

@@ -1 +1,2 @@
FROM qmcgaw/godevcontainer
RUN apk add wireguard-tools

View File

@@ -4,6 +4,8 @@ services:
vscode:
build: .
image: godevcontainer
devices:
- /dev/net/tun:/dev/net/tun
volumes:
- ../:/workspace
# Docker socket to access Docker server
@@ -23,7 +25,8 @@ services:
- TZ=
cap_add:
# For debugging with dlv
- SYS_PTRACE
# - SYS_PTRACE
- NET_ADMIN
security_opt:
# For debugging with dlv
- seccomp:unconfined

3
.github/labels.yml vendored
View File

@@ -70,6 +70,9 @@
- name: "Openvpn"
color: "ffc7ea"
description: ""
- name: "Wireguard"
color: "ffc7ea"
description: ""
- name: "Unbound (DNS over TLS)"
color: "ffc7ea"
description: ""

View File

@@ -68,6 +68,7 @@ LABEL \
org.opencontainers.image.description="VPN swiss-knife like client to tunnel to multiple VPN servers using OpenVPN, IPtables, DNS over TLS, Shadowsocks, an HTTP proxy and Alpine Linux"
ENV VPNSP=pia \
VERSION_INFORMATION=on \
VPN_TYPE=openvpn \
PROTOCOL=udp \
OPENVPN_VERSION=2.5 \
OPENVPN_VERBOSITY=1 \
@@ -77,6 +78,11 @@ ENV VPNSP=pia \
OPENVPN_IPV6=off \
OPENVPN_CUSTOM_CONFIG= \
OPENVPN_INTERFACE=tun0 \
WIREGUARD_PRIVATE_KEY= \
WIREGUARD_PRESHARED_KEY= \
WIREGUARD_ADDRESS= \
WIREGUARD_PORT= \
WIREGUARD_INTERFACE=wg0 \
TZ= \
PUID= \
PGID= \

View File

@@ -5,7 +5,7 @@ HideMyAss, IPVanish, IVPN, Mullvad, NordVPN, Privado, Private Internet Access, P
ProtonVPN, PureVPN, Surfshark, TorGuard, VPNUnlimited, VyprVPN and Windscribe VPN servers
using Go, OpenVPN, iptables, DNS over TLS, ShadowSocks and an HTTP proxy*
**ANNOUNCEMENT**: You can try Wireguard, see #565
**ANNOUNCEMENT**: Wireguard is now supported for all providers supporting it!
![Title image](https://raw.githubusercontent.com/qdm12/gluetun/master/title.svg)
@@ -55,9 +55,10 @@ using Go, OpenVPN, iptables, DNS over TLS, ShadowSocks and an HTTP proxy*
## Features
- Based on Alpine 3.14 for a small Docker image of 30MB
- Based on Alpine 3.14 for a small Docker image of 31MB
- Supports: **Cyberghost**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **Surfshark**, **TorGuard**, **VPNUnlimited**, **Vyprvpn**, **Windscribe** servers
- Supports OpenVPN and Wireguard (the latter in progress, see PR #565 and issue #134)
- Supports OpenVPN
- Supports Wireguard for **Mullvad** and **Windscribe** (more in progress, see #134)
- DNS over TLS baked in with service provider(s) of your choice
- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours
- Choose the vpn network protocol, `udp` or `tcp`
@@ -110,6 +111,7 @@ The following points are all optional but should give you insights on all the po
- [Test your setup](https://github.com/qdm12/gluetun/wiki/Test-your-setup)
- [How to connect other containers and devices to Gluetun](https://github.com/qdm12/gluetun/wiki/Connect-to-gluetun)
- [How to use Wireguard](https://github.com/qdm12/gluetun/wiki/Wireguard)
- [VPN server side port forwarding](https://github.com/qdm12/gluetun/wiki/Port-forwarding)
- [HTTP control server](https://github.com/qdm12/gluetun/wiki/HTTP-Control-server) to automate things, restart Openvpn etc.
- Update the image with `docker pull qmcgaw/gluetun:latest`. See this [Wiki document](https://github.com/qdm12/gluetun/wiki/Docker-image-tags) for Docker tags available.

View File

@@ -22,6 +22,7 @@ import (
"github.com/qdm12/gluetun/internal/healthcheck"
"github.com/qdm12/gluetun/internal/httpproxy"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip"
@@ -69,13 +70,14 @@ func main() {
args := os.Args
tun := tun.New()
netLinker := netlink.New()
cli := cli.New()
env := params.NewEnv()
cmder := command.NewCmder()
errorCh := make(chan error)
go func() {
errorCh <- _main(ctx, buildInfo, args, logger, env, tun, cmder, cli)
errorCh <- _main(ctx, buildInfo, args, logger, env, tun, netLinker, cmder, cli)
}()
select {
@@ -116,7 +118,8 @@ var (
//nolint:gocognit,gocyclo
func _main(ctx context.Context, buildInfo models.BuildInformation,
args []string, logger logging.ParentLogger, env params.Env,
tun tun.Interface, cmder command.RunStarter, cli cli.CLIer) error {
tun tun.Interface, netLinker netlink.NetLinker, cmder command.RunStarter,
cli cli.CLIer) error {
if len(args) > 1 { // cli operation
switch args[1] {
case "healthcheck":
@@ -153,7 +156,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
dnsConf := unbound.NewConfigurator(nil, cmder, dnsCrypto,
"/etc/unbound", "/usr/sbin/unbound", cacertsPath)
announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z")
announcementExp, err := time.Parse(time.RFC3339, "2021-10-02T00:00:00Z")
if err != nil {
return err
}
@@ -164,7 +167,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
Version: buildInfo.Version,
Commit: buildInfo.Commit,
BuildDate: buildInfo.Created,
Announcement: "",
Announcement: "Wireguard is now supported!",
AnnounceExp: announcementExp,
// Sponsor information
PaypalUser: "qmcgaw",
@@ -357,12 +360,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "})
vpnLooper := vpn.NewLoop(allSettings.VPN,
allServers, ovpnConf, firewallConf, routingConf, portForwardLooper,
allServers, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
buildInfo, allSettings.VersionInformation)
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
go vpnLooper.Run(openvpnCtx, openvpnDone)
vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler(
"vpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
go vpnLooper.Run(vpnCtx, vpnDone)
updaterLooper := updater.NewLooper(allSettings.Updater,
allServers, storage, vpnLooper.SetServers, httpClient,
@@ -417,7 +420,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
}
orderHandler := goshutdown.NewOrder("gluetun", orderSettings)
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
openvpnHandler, portForwardHandler, otherGroupHandler)
vpnHandler, portForwardHandler, otherGroupHandler)
// Start VPN for the first time in a blocking call
// until the VPN is launched

View File

@@ -15,10 +15,11 @@ services:
volumes:
- /yourpath:/gluetun
environment:
# More variables are available, see the readme table
# More variables are available, see the Wiki table
- OPENVPN_USER=
- OPENVPN_PASSWORD=
- VPNSP=private internet access
- VPN_TYPE=openvpn
# Timezone for accurate logs times
- TZ=
restart: always

8
go.mod
View File

@@ -14,13 +14,19 @@ require (
github.com/stretchr/testify v1.7.0
github.com/vishvananda/netlink v1.1.0
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c
inet.af/netaddr v0.0.0-20210718074554-06ca8145d722
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.5.5 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.12 // indirect
github.com/mdlayher/genetlink v1.0.0 // indirect
github.com/mdlayher/netlink v1.4.0 // indirect
github.com/miekg/dns v1.1.40 // indirect
github.com/mr-tron/base58 v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -29,6 +35,6 @@ require (
go4.org/intern v0.0.0-20210108033219-3eb7198706b2 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 // indirect
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 // indirect
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

73
go.sum
View File

@@ -33,11 +33,28 @@ github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3K
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gotify/go-api-client/v2 v2.0.4/go.mod h1:VKiah/UK20bXsr0JObE1eBVLW44zbBouzjuri9iwjFU=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 h1:uhL5Gw7BINiiPAo24A2sxkcDI0Jt/sqp1v5xQCniEFA=
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
github.com/jsimonetti/rtnetlink v0.0.0-20201216134343-bde56ed16391/go.mod h1:cR77jAZG3Y3bsb8hF6fHJbFoyFukLFOkQ98S0pQz3xw=
github.com/jsimonetti/rtnetlink v0.0.0-20201220180245-69540ac93943/go.mod h1:z4c53zj6Eex712ROyh8WI0ihysb5j2ROyV42iNogmAs=
github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA=
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b h1:c3NTyLNozICy8B4mlMXemD3z/gXgQzVXZS/HqT+i3do=
github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -51,8 +68,24 @@ github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43 h1:WgyLFv10Ov49JAQI/ZLUkCZ7VJS3r74hwFIGXJsgZlY=
github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo=
github.com/mdlayher/genetlink v1.0.0 h1:OoHN1OdyEIkScEmRgxLEe2M9U8ClMytqA5niynLtfj0=
github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc=
github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA=
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY=
github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o=
github.com/mdlayher/netlink v1.2.0/go.mod h1:kwVW1io0AZy9A1E2YYgaD4Cj+C+GPkU6klXCMzIJ9p8=
github.com/mdlayher/netlink v1.2.1/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU=
github.com/mdlayher/netlink v1.2.2-0.20210123213345-5cc92139ae3e/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU=
github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys=
github.com/mdlayher/netlink v1.4.0 h1:n3ARR+Fm0dDv37dj5wSWZXDKcy+U0zwcXS3zKMnSiT0=
github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8=
github.com/miekg/dns v1.1.40 h1:pyyPFfGMnciYUk/mXpKkVmeMQjfXqt3FAJ2hy7tPiLA=
github.com/miekg/dns v1.1.40/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
@@ -106,6 +139,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210503195802-e9a32991a82e/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
@@ -113,38 +148,67 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d h1:nTDGCTeAu2LhcsHTRzjyIUbZHCJ4QePArsm27Hka0UM=
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190221075227-b4e8571b14e0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201218084310-7d0127a74742/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210503173754-0981d6026fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190729092621-ff9f1409240a/go.mod h1:jcCCGcm9btYwXyDqrUWc6MKQKKGJCWEQ3AfLSRIbEuI=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
@@ -153,7 +217,14 @@ golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20210427022245-097af6e1351b/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 h1:ab2jcw2W91Rz07eHAb8Lic7sFQKO0NhBftjv6m/gL/0=
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19/go.mod h1:laHzsbfMhGSobUmruXWAyMKKHSqvIcrqZJMyHD+/3O8=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c h1:ADNrRDI5NR23/TUCnEmlLZLt4u9DnZ2nwRkPrAcFvto=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c/go.mod h1:+1XihzyZUBJcSc5WO9SwNA7v26puQwOEDwanaxfNXPQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -49,7 +49,7 @@ func (settings *OpenVPNSelection) readMullvad(env params.Env) (err error) {
return err
}
settings.CustomPort, err = readCustomPort(env, settings.TCP,
settings.CustomPort, err = readOpenVPNCustomPort(env, settings.TCP,
[]uint16{80, 443, 1401}, []uint16{53, 1194, 1195, 1196, 1197, 1300, 1301, 1302, 1303, 1400})
if err != nil {
return err

View File

@@ -43,7 +43,7 @@ var (
)
func (settings *Provider) read(r reader, vpnType string) error {
err := settings.readVPNServiceProvider(r)
err := settings.readVPNServiceProvider(r, vpnType)
if err != nil {
return err
}
@@ -94,11 +94,17 @@ func (settings *Provider) read(r reader, vpnType string) error {
return nil
}
func (settings *Provider) readVPNServiceProvider(r reader) (err error) {
allowedVPNServiceProviders := []string{
"cyberghost", "fastestvpn", "hidemyass", "ipvanish", "ivpn", "mullvad", "nordvpn",
"privado", "pia", "private internet access", "privatevpn", "protonvpn",
"purevpn", "surfshark", "torguard", constants.VPNUnlimited, "vyprvpn", "windscribe"}
func (settings *Provider) readVPNServiceProvider(r reader, vpnType string) (err error) {
var allowedVPNServiceProviders []string
switch vpnType {
case constants.OpenVPN:
allowedVPNServiceProviders = []string{
"cyberghost", "fastestvpn", "hidemyass", "ipvanish", "ivpn", "mullvad", "nordvpn",
"privado", "pia", "private internet access", "privatevpn", "protonvpn",
"purevpn", "surfshark", "torguard", constants.VPNUnlimited, "vyprvpn", "windscribe"}
case constants.Wireguard:
allowedVPNServiceProviders = []string{constants.Mullvad, constants.Windscribe}
}
vpnsp, err := r.env.Inside("VPNSP", allowedVPNServiceProviders,
params.Default("private internet access"))
@@ -132,7 +138,7 @@ func readTargetIP(env params.Env) (targetIP net.IP, err error) {
return targetIP, nil
}
func readCustomPort(env params.Env, tcp bool,
func readOpenVPNCustomPort(env params.Env, tcp bool,
allowedTCP, allowedUDP []uint16) (port uint16, err error) {
port, err = readPortOrZero(env, "PORT")
if err != nil {
@@ -147,12 +153,42 @@ func readCustomPort(env params.Env, tcp bool,
return port, nil
}
}
return 0, fmt.Errorf("environment variable PORT: %w: port %d for TCP protocol", ErrInvalidPort, port)
return 0, fmt.Errorf(
"environment variable PORT: %w: port %d for TCP protocol, can only be one of %s",
ErrInvalidPort, port, portsToString(allowedTCP))
}
for i := range allowedUDP {
if allowedUDP[i] == port {
return port, nil
}
}
return 0, fmt.Errorf("environment variable PORT: %w: port %d for UDP protocol", ErrInvalidPort, port)
return 0, fmt.Errorf(
"environment variable PORT: %w: port %d for UDP protocol, can only be one of %s",
ErrInvalidPort, port, portsToString(allowedUDP))
}
func readWireguardCustomPort(env params.Env, allowed []uint16) (port uint16, err error) {
port, err = readPortOrZero(env, "WIREGUARD_PORT")
if err != nil {
return 0, fmt.Errorf("environment variable WIREGUARD_PORT: %w", err)
} else if port == 0 {
return 0, nil
}
for i := range allowed {
if allowed[i] == port {
return port, nil
}
}
return 0, fmt.Errorf(
"environment variable WIREGUARD_PORT: %w: port %d, can only be one of %s",
ErrInvalidPort, port, portsToString(allowed))
}
func portsToString(ports []uint16) string {
slice := make([]string, len(ports))
for i := range ports {
slice[i] = fmt.Sprint(ports[i])
}
return strings.Join(slice, ", ")
}

View File

@@ -24,6 +24,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Cyberghost,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Groups: []string{"group"},
Regions: []string{"a", "El country"},
},
@@ -40,6 +41,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Fastestvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Hostnames: []string{"a", "b"},
Countries: []string{"c", "d"},
},
@@ -56,6 +58,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.HideMyAss,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -74,6 +77,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Ipvanish,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -92,6 +96,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Ivpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -110,6 +115,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Mullvad,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
ISPs: []string{"e", "f"},
@@ -132,6 +138,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Nordvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
Numbers: []uint16{1, 2},
},
@@ -148,6 +155,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Privado,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Hostnames: []string{"a", "b"},
},
},
@@ -162,6 +170,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Privatevpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Hostnames: []string{"a", "b"},
Countries: []string{"c", "d"},
Cities: []string{"e", "f"},
@@ -180,6 +189,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Protonvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Regions: []string{"c", "d"},
Cities: []string{"e", "f"},
@@ -202,6 +212,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.PrivateInternetAccess,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
OpenVPN: OpenVPNSelection{
CustomPort: 1,
@@ -226,6 +237,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Purevpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
Countries: []string{"c", "d"},
Cities: []string{"e", "f"},
@@ -244,6 +256,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Surfshark,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
},
},
@@ -258,6 +271,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Torguard,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e"},
@@ -276,6 +290,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.VPNUnlimited,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -298,6 +313,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Vyprvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
},
},
@@ -312,6 +328,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Windscribe,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},

View File

@@ -109,12 +109,12 @@ func readIP(env params.Env, key string) (ip net.IP, err error) {
}
func readPortOrZero(env params.Env, key string) (port uint16, err error) {
s, err := env.Get(key)
s, err := env.Get(key, params.Default("0"))
if err != nil {
return 0, err
}
if s == "" || s == "0" {
if s == "0" {
return 0, nil
}

View File

@@ -4,12 +4,13 @@ import (
"fmt"
"net"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/params"
)
type ServerSelection struct { //nolint:maligned
// Common
VPN string `json:"vpn"`
VPN string `json:"vpn"` // note: this is required
TargetIP net.IP `json:"target_ip,omitempty"`
// TODO comments
// Cyberghost, PIA, Protonvpn, Surfshark, Windscribe, Vyprvpn, NordVPN
@@ -39,7 +40,8 @@ type ServerSelection struct { //nolint:maligned
// VPNUnlimited
StreamOnly bool `json:"stream_only"`
OpenVPN OpenVPNSelection `json:"openvpn"`
OpenVPN OpenVPNSelection `json:"openvpn"`
Wireguard WireguardSelection `json:"wireguard"`
}
func (selection ServerSelection) toLines() (lines []string) {
@@ -91,7 +93,11 @@ func (selection ServerSelection) toLines() (lines []string) {
lines = append(lines, lastIndent+"Numbers: "+commaJoin(numbersString))
}
lines = append(lines, selection.OpenVPN.lines()...)
if selection.VPN == constants.OpenVPN {
lines = append(lines, selection.OpenVPN.lines()...)
} else { // wireguard
lines = append(lines, selection.Wireguard.lines()...)
}
return lines
}
@@ -137,6 +143,20 @@ func (settings *OpenVPNSelection) readProtocolAndPort(env params.Env) (err error
return nil
}
type WireguardSelection struct {
CustomPort uint16 `json:"custom_port"` // Mullvad
}
func (settings *WireguardSelection) lines() (lines []string) {
lines = append(lines, lastIndent+"Wireguard selection:")
if settings.CustomPort != 0 {
lines = append(lines, indent+lastIndent+"Custom port: "+fmt.Sprint(settings.CustomPort))
}
return lines
}
// PortForwarding contains settings for port forwarding.
type PortForwarding struct {
Enabled bool `json:"enabled"`

View File

@@ -20,6 +20,9 @@ func Test_Settings_lines(t *testing.T) {
Type: constants.OpenVPN,
Provider: Provider{
Name: constants.Mullvad,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
},
},
OpenVPN: OpenVPN{
Version: constants.Openvpn25,

View File

@@ -10,9 +10,10 @@ import (
)
type VPN struct {
Type string `json:"type"`
OpenVPN OpenVPN `json:"openvpn"`
Provider Provider `json:"provider"`
Type string `json:"type"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
Provider Provider `json:"provider"`
}
func (settings *VPN) String() string {
@@ -24,7 +25,14 @@ func (settings *VPN) lines() (lines []string) {
lines = append(lines, indent+lastIndent+"Type: "+settings.Type)
for _, line := range settings.OpenVPN.lines() {
var vpnLines []string
switch settings.Type {
case constants.OpenVPN:
vpnLines = settings.OpenVPN.lines()
case constants.Wireguard:
vpnLines = settings.Wireguard.lines()
}
for _, line := range vpnLines {
lines = append(lines, indent+line)
}
@@ -36,13 +44,15 @@ func (settings *VPN) lines() (lines []string) {
}
var (
errReadProviderSettings = errors.New("cannot read provider settings")
errReadOpenVPNSettings = errors.New("cannot read OpenVPN settings")
errReadProviderSettings = errors.New("cannot read provider settings")
errReadOpenVPNSettings = errors.New("cannot read OpenVPN settings")
errReadWireguardSettings = errors.New("cannot read Wireguard settings")
)
func (settings *VPN) read(r reader) (err error) {
vpnType, err := r.env.Inside("VPN_TYPE",
[]string{constants.OpenVPN}, params.Default(constants.OpenVPN))
[]string{constants.OpenVPN, constants.Wireguard},
params.Default(constants.OpenVPN))
if err != nil {
return fmt.Errorf("environment variable VPN_TYPE: %w", err)
}
@@ -54,9 +64,17 @@ func (settings *VPN) read(r reader) (err error) {
}
}
err = settings.OpenVPN.read(r, settings.Provider.Name)
if err != nil {
return fmt.Errorf("%w: %s", errReadOpenVPNSettings, err)
switch settings.Type {
case constants.OpenVPN:
err = settings.OpenVPN.read(r, settings.Provider.Name)
if err != nil {
return fmt.Errorf("%w: %s", errReadOpenVPNSettings, err)
}
case constants.Wireguard:
err = settings.Wireguard.read(r)
if err != nil {
return fmt.Errorf("%w: %s", errReadWireguardSettings, err)
}
}
return nil

View File

@@ -30,7 +30,12 @@ func (settings *Provider) readWindscribe(r reader) (err error) {
return fmt.Errorf("environment variable SERVER_HOSTNAME: %w", err)
}
return settings.ServerSelection.OpenVPN.readWindscribe(r.env)
err = settings.ServerSelection.OpenVPN.readWindscribe(r.env)
if err != nil {
return err
}
return settings.ServerSelection.Wireguard.readWindscribe(r.env)
}
func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
@@ -39,7 +44,7 @@ func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
return err
}
settings.CustomPort, err = readCustomPort(env, settings.TCP,
settings.CustomPort, err = readOpenVPNCustomPort(env, settings.TCP,
[]uint16{21, 22, 80, 123, 143, 443, 587, 1194, 3306, 8080, 54783},
[]uint16{53, 80, 123, 443, 1194, 54783})
if err != nil {
@@ -48,3 +53,13 @@ func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
return nil
}
func (settings *WireguardSelection) readWindscribe(env params.Env) (err error) {
settings.CustomPort, err = readWireguardCustomPort(env,
[]uint16{53, 80, 123, 443, 1194, 65142})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,73 @@
package configuration
import (
"fmt"
"net"
"strings"
"github.com/qdm12/golibs/params"
)
// Wireguard contains settings to configure the Wireguard client.
type Wireguard struct {
PrivateKey string `json:"privatekey"`
PreSharedKey string `json:"presharedkey"`
Address *net.IPNet `json:"address"`
Interface string `json:"interface"`
}
func (settings *Wireguard) String() string {
return strings.Join(settings.lines(), "\n")
}
func (settings *Wireguard) lines() (lines []string) {
lines = append(lines, lastIndent+"Wireguard:")
lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface)
if settings.PrivateKey != "" {
lines = append(lines, indent+lastIndent+"Private key is set")
}
if settings.PreSharedKey != "" {
lines = append(lines, indent+lastIndent+"Pre-shared key is set")
}
if settings.Address != nil {
lines = append(lines, indent+lastIndent+"Address: "+settings.Address.String())
}
return lines
}
func (settings *Wireguard) read(r reader) (err error) {
settings.PrivateKey, err = r.env.Get("WIREGUARD_PRIVATE_KEY",
params.CaseSensitiveValue(), params.Unset(), params.Compulsory())
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_PRIVATE_KEY: %w", err)
}
settings.PreSharedKey, err = r.env.Get("WIREGUARD_PRESHARED_KEY",
params.CaseSensitiveValue(), params.Unset())
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_PRESHARED_KEY: %w", err)
}
addressString, err := r.env.Get("WIREGUARD_ADDRESS", params.Compulsory())
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_ADDRESS: %w", err)
}
ip, ipNet, err := net.ParseCIDR(addressString)
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_ADDRESS: %w", err)
}
ipNet.IP = ip
settings.Address = ipNet
settings.Interface, err = r.env.Get("WIREGUARD_INTERFACE", params.Default("wg0"))
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_INTERFACE: %w", err)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -73,7 +73,7 @@ func Test_versions(t *testing.T) {
"Mullvad": {
model: models.MullvadServer{},
version: allServers.Mullvad.Version,
digest: "2a009192",
digest: "ec56f19d",
},
"Nordvpn": {
model: models.NordvpnServer{},
@@ -128,7 +128,7 @@ func Test_versions(t *testing.T) {
"Windscribe": {
model: models.WindscribeServer{},
version: allServers.Windscribe.Version,
digest: "6f6c16d6",
digest: "4bd0fc4f",
},
}
for name, testCase := range testCases {

View File

@@ -1,7 +1,8 @@
package constants
const (
OpenVPN = "openvpn"
OpenVPN = "openvpn"
Wireguard = "wireguard"
)
const (

View File

@@ -6,7 +6,7 @@ import (
)
type Connection struct {
// Type is the connection type and can be "openvpn"
// Type is the connection type and can be "openvpn" or "wireguard"
Type string `json:"type"`
// IP is the VPN server IP address.
IP net.IP `json:"ip"`
@@ -15,13 +15,17 @@ type Connection struct {
// Protocol can be "tcp" or "udp".
Protocol string `json:"protocol"`
// Hostname is used for IPVanish, IVPN, Privado
// and Windscribe for TLS verification
// and Windscribe for TLS verification.
Hostname string `json:"hostname"`
// PubKey is the public key of the VPN server,
// used only for Wireguard.
PubKey string `json:"pubkey"`
}
func (c *Connection) Equal(other Connection) bool {
return c.IP.Equal(other.IP) && c.Port == other.Port &&
c.Protocol == other.Protocol && c.Hostname == other.Hostname
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
c.PubKey == other.PubKey
}
func (c Connection) OpenVPNRemoteLine() (line string) {

View File

@@ -48,6 +48,7 @@ type IvpnServer struct {
}
type MullvadServer struct {
VPN string `json:"vpn"`
IPs []net.IP `json:"ips"`
IPsV6 []net.IP `json:"ipsv6"`
Country string `json:"country"`
@@ -55,6 +56,7 @@ type MullvadServer struct {
Hostname string `json:"hostname"`
ISP string `json:"isp"`
Owned bool `json:"owned"`
WgPubKey string `json:"wgpubkey,omitempty"`
}
type NordvpnServer struct { //nolint:maligned
@@ -149,9 +151,11 @@ type VyprvpnServer struct {
}
type WindscribeServer struct {
VPN string `json:"vpn"`
Region string `json:"region"`
City string `json:"city"`
Hostname string `json:"hostname"`
OvpnX509 string `json:"x509"`
OvpnX509 string `json:"x509,omitempty"`
WgPubKey string `json:"wgpubkey,omitempty"`
IPs []net.IP `json:"ips"`
}

View File

@@ -0,0 +1,7 @@
package netlink
import "github.com/vishvananda/netlink"
func (n *NetLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error {
return netlink.AddrAdd(link, addr)
}

View File

@@ -0,0 +1,14 @@
package netlink
import "github.com/vishvananda/netlink"
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . NetLinker
var _ NetLinker = (*NetLink)(nil)
type NetLinker interface {
AddrAdd(link netlink.Link, addr *netlink.Addr) error
RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
}

View File

@@ -0,0 +1,7 @@
package netlink
type NetLink struct{}
func New() *NetLink {
return &NetLink{}
}

View File

@@ -0,0 +1,7 @@
package netlink
import "github.com/vishvananda/netlink"
func (n *NetLink) RouteAdd(route *netlink.Route) error {
return netlink.RouteAdd(route)
}

11
internal/netlink/rule.go Normal file
View File

@@ -0,0 +1,11 @@
package netlink
import "github.com/vishvananda/netlink"
func (n *NetLink) RuleAdd(rule *netlink.Rule) error {
return netlink.RuleAdd(rule)
}
func (n *NetLink) RuleDel(rule *netlink.Rule) error {
return netlink.RuleDel(rule)
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -19,7 +20,8 @@ func Test_Cyberghost_filterServers(t *testing.T) {
err error
}{
"no servers": {
err: errors.New("no server found: for protocol udp"),
selection: configuration.ServerSelection{VPN: constants.OpenVPN},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"servers without filter defaults to UDP": {
servers: []models.CyberghostServer{

View File

@@ -9,16 +9,8 @@ import (
func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) {
var port uint16 = 1194
protocol := constants.UDP
if selection.OpenVPN.TCP {
port = 443
protocol = constants.TCP
}
if selection.OpenVPN.CustomPort > 0 {
port = selection.OpenVPN.CustomPort
}
port := getPort(selection)
protocol := getProtocol(selection)
servers, err := m.filterServers(selection)
if err != nil {
@@ -33,6 +25,7 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
IP: IP,
Port: port,
Protocol: protocol,
PubKey: server.WgPubKey, // Wireguard only
}
connections = append(connections, connection)
}
@@ -44,3 +37,33 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
return utils.PickRandomConnection(connections, m.randSource), nil
}
func getPort(selection configuration.ServerSelection) (port uint16) {
switch selection.VPN {
case constants.Wireguard:
customPort := selection.Wireguard.CustomPort
if customPort > 0 {
return customPort
}
const defaultPort = 51820
return defaultPort
default: // OpenVPN
customPort := selection.OpenVPN.CustomPort
if customPort > 0 {
return customPort
}
port = 1194
if selection.OpenVPN.TCP {
port = 443
}
return port
}
}
func getProtocol(selection configuration.ServerSelection) (protocol string) {
protocol = constants.UDP
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
protocol = constants.TCP
}
return protocol
}

View File

@@ -0,0 +1,204 @@
package mullvad
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Mullvad_GetConnection(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.MullvadServer
selection configuration.ServerSelection
connection models.Connection
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.MullvadServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(1, 1, 1, 1),
Port: 1194,
Protocol: constants.UDP,
},
},
"target IP": {
selection: configuration.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
},
servers: []models.MullvadServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.MullvadServer{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
connection, err := m.GetConnection(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.connection, connection)
})
}
}
func Test_getPort(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
port uint16
}{
"default": {
port: 1194,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
port: 1194,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
port: 443,
},
"OpenVPN custom port": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
CustomPort: 1234,
},
},
port: 1234,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
port: 51820,
},
"Wireguard custom port": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
Wireguard: configuration.WireguardSelection{
CustomPort: 1234,
},
},
port: 1234,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
port := getPort(testCase.selection)
assert.Equal(t, testCase.port, port)
})
}
}
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -11,6 +11,7 @@ func (m *Mullvad) filterServers(selection configuration.ServerSelection) (
for _, server := range m.servers {
switch {
case
server.VPN != selection.VPN,
utils.FilterByPossibilities(server.Country, selection.Countries),
utils.FilterByPossibilities(server.City, selection.Cities),
utils.FilterByPossibilities(server.ISP, selection.ISPs),

View File

@@ -0,0 +1,143 @@
package mullvad
import (
"errors"
"math/rand"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Mullvad_filterServers(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.MullvadServer
selection configuration.ServerSelection
filtered []models.MullvadServer
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
},
"filter OpenVPN out": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
servers: []models.MullvadServer{
{VPN: constants.OpenVPN, Hostname: "a"},
{VPN: constants.Wireguard, Hostname: "b"},
{VPN: constants.OpenVPN, Hostname: "c"},
},
filtered: []models.MullvadServer{
{VPN: constants.Wireguard, Hostname: "b"},
},
},
"filter by country": {
selection: configuration.ServerSelection{
Countries: []string{"b"},
},
servers: []models.MullvadServer{
{Country: "a"},
{Country: "b"},
{Country: "c"},
},
filtered: []models.MullvadServer{
{Country: "b"},
},
},
"filter by city": {
selection: configuration.ServerSelection{
Cities: []string{"b"},
},
servers: []models.MullvadServer{
{City: "a"},
{City: "b"},
{City: "c"},
},
filtered: []models.MullvadServer{
{City: "b"},
},
},
"filter by ISP": {
selection: configuration.ServerSelection{
ISPs: []string{"b"},
},
servers: []models.MullvadServer{
{ISP: "a"},
{ISP: "b"},
{ISP: "c"},
},
filtered: []models.MullvadServer{
{ISP: "b"},
},
},
"filter by hostname": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.MullvadServer{
{Hostname: "b"},
},
},
"filter by owned": {
selection: configuration.ServerSelection{
Owned: true,
},
servers: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b", Owned: true},
{Hostname: "c"},
},
filtered: []models.MullvadServer{
{Hostname: "b", Owned: true},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
servers, err := m.filterServers(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.filtered, servers)
})
}
}

View File

@@ -19,6 +19,8 @@ var ErrNoServerFound = errors.New("no server found")
func NoServerFoundError(selection configuration.ServerSelection) (err error) {
var messageParts []string
messageParts = append(messageParts, "VPN "+selection.VPN)
protocol := constants.UDP
if selection.OpenVPN.TCP {
protocol = constants.TCP

View File

@@ -0,0 +1,34 @@
package utils
import (
"net"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/wireguard"
)
func BuildWireguardSettings(connection models.Connection,
userSettings configuration.Wireguard) (settings wireguard.Settings) {
settings.PrivateKey = userSettings.PrivateKey
settings.PublicKey = connection.PubKey
settings.PreSharedKey = userSettings.PreSharedKey
settings.InterfaceName = userSettings.Interface
const routePriority = 101 // 100 is to receive external connections
settings.RulePriority = routePriority
settings.Endpoint = new(net.UDPAddr)
settings.Endpoint.IP = make(net.IP, len(connection.IP))
copy(settings.Endpoint.IP, connection.IP)
settings.Endpoint.Port = int(connection.Port)
address := new(net.IPNet)
address.IP = make(net.IP, len(userSettings.Address.IP))
copy(address.IP, userSettings.Address.IP)
address.Mask = make(net.IPMask, len(userSettings.Address.Mask))
copy(address.Mask, userSettings.Address.Mask)
settings.Addresses = append(settings.Addresses, address)
return settings
}

View File

@@ -9,16 +9,8 @@ import (
func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) {
protocol := constants.UDP
var port uint16 = 443
if selection.OpenVPN.TCP {
protocol = constants.TCP
port = 1194
}
if selection.OpenVPN.CustomPort > 0 {
port = selection.OpenVPN.CustomPort
}
port := getPort(selection)
protocol := getProtocol(selection)
servers, err := w.filterServers(selection)
if err != nil {
@@ -34,6 +26,7 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
Port: port,
Protocol: protocol,
Hostname: server.OvpnX509,
PubKey: server.WgPubKey,
}
connections = append(connections, connection)
}
@@ -45,3 +38,33 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
return utils.PickRandomConnection(connections, w.randSource), nil
}
func getPort(selection configuration.ServerSelection) (port uint16) {
switch selection.VPN {
case constants.Wireguard:
customPort := selection.Wireguard.CustomPort
if customPort > 0 {
return customPort
}
const defaultPort = 1194
return defaultPort
default: // OpenVPN
customPort := selection.OpenVPN.CustomPort
if customPort > 0 {
return customPort
}
port = 1194
if selection.OpenVPN.TCP {
port = 443
}
return port
}
}
func getProtocol(selection configuration.ServerSelection) (protocol string) {
protocol = constants.UDP
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
protocol = constants.TCP
}
return protocol
}

View File

@@ -0,0 +1,204 @@
package windscribe
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Windscribe_GetConnection(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.WindscribeServer
selection configuration.ServerSelection
connection models.Connection
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.WindscribeServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(1, 1, 1, 1),
Port: 1194,
Protocol: constants.UDP,
},
},
"target IP": {
selection: configuration.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
},
servers: []models.WindscribeServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.WindscribeServer{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
connection, err := m.GetConnection(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.connection, connection)
})
}
}
func Test_getPort(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
port uint16
}{
"default": {
port: 1194,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
port: 1194,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
port: 443,
},
"OpenVPN custom port": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
CustomPort: 1234,
},
},
port: 1234,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
port: 1194,
},
"Wireguard custom port": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
Wireguard: configuration.WireguardSelection{
CustomPort: 1234,
},
},
port: 1234,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
port := getPort(testCase.selection)
assert.Equal(t, testCase.port, port)
})
}
}
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -11,6 +11,7 @@ func (w *Windscribe) filterServers(selection configuration.ServerSelection) (
for _, server := range w.servers {
switch {
case
server.VPN != selection.VPN,
utils.FilterByPossibilities(server.Region, selection.Regions),
utils.FilterByPossibilities(server.City, selection.Cities),
utils.FilterByPossibilities(server.Hostname, selection.Hostnames):

View File

@@ -0,0 +1,117 @@
package windscribe
import (
"errors"
"math/rand"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Windscribe_filterServers(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.WindscribeServer
selection configuration.ServerSelection
filtered []models.WindscribeServer
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.WindscribeServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.WindscribeServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
},
"filter OpenVPN out": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
servers: []models.WindscribeServer{
{VPN: constants.OpenVPN, Hostname: "a"},
{VPN: constants.Wireguard, Hostname: "b"},
{VPN: constants.OpenVPN, Hostname: "c"},
},
filtered: []models.WindscribeServer{
{VPN: constants.Wireguard, Hostname: "b"},
},
},
"filter by region": {
selection: configuration.ServerSelection{
Regions: []string{"b"},
},
servers: []models.WindscribeServer{
{Region: "a"},
{Region: "b"},
{Region: "c"},
},
filtered: []models.WindscribeServer{
{Region: "b"},
},
},
"filter by city": {
selection: configuration.ServerSelection{
Cities: []string{"b"},
},
servers: []models.WindscribeServer{
{City: "a"},
{City: "b"},
{City: "c"},
},
filtered: []models.WindscribeServer{
{City: "b"},
},
},
"filter by hostname": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.WindscribeServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.WindscribeServer{
{Hostname: "b"},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
servers, err := m.filterServers(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.filtered, servers)
})
}
}

View File

@@ -22,10 +22,12 @@ type serverData struct {
Provider string `json:"provider"`
IPv4 string `json:"ipv4_addr_in"`
IPv6 string `json:"ipv6_addr_in"`
Type string `json:"type"`
PubKey string `json:"pubkey"` // Wireguard public key
}
func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err error) {
const url = "https://api.mullvad.net/www/relays/openvpn/"
const url = "https://api.mullvad.net/www/relays/all/"
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {

View File

@@ -6,14 +6,17 @@ import (
"net"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
type hostToServer map[string]models.MullvadServer
var (
ErrParseIPv4 = errors.New("cannot parse IPv4 address")
ErrParseIPv6 = errors.New("cannot parse IPv6 address")
ErrNoIP = errors.New("no IP address for VPN server")
ErrParseIPv4 = errors.New("cannot parse IPv4 address")
ErrParseIPv6 = errors.New("cannot parse IPv6 address")
ErrVPNTypeNotSupported = errors.New("VPN type not supported")
)
func (hts hostToServer) add(data serverData) (err error) {
@@ -21,14 +24,8 @@ func (hts hostToServer) add(data serverData) (err error) {
return
}
ipv4 := net.ParseIP(data.IPv4)
if ipv4 == nil || ipv4.To4() == nil {
return fmt.Errorf("%w: %s", ErrParseIPv4, data.IPv4)
}
ipv6 := net.ParseIP(data.IPv6)
if ipv6 == nil || ipv6.To4() != nil {
return fmt.Errorf("%w: %s", ErrParseIPv6, data.IPv6)
if data.IPv4 == "" && data.IPv6 == "" {
return ErrNoIP
}
server, ok := hts[data.Hostname]
@@ -36,13 +33,40 @@ func (hts hostToServer) add(data serverData) (err error) {
return nil
}
switch data.Type {
case "openvpn":
server.VPN = constants.OpenVPN
case "wireguard":
server.VPN = constants.Wireguard
case "bridge":
// ignore bridge servers
return nil
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, data.Type)
}
if data.IPv4 != "" {
ipv4 := net.ParseIP(data.IPv4)
if ipv4 == nil || ipv4.To4() == nil {
return fmt.Errorf("%w: %s", ErrParseIPv4, data.IPv4)
}
server.IPs = []net.IP{ipv4}
}
if data.IPv6 != "" {
ipv6 := net.ParseIP(data.IPv6)
if ipv6 == nil || ipv6.To4() != nil {
return fmt.Errorf("%w: %s", ErrParseIPv6, data.IPv6)
}
server.IPsV6 = []net.IP{ipv6}
}
server.Country = data.Country
server.City = strings.ReplaceAll(data.City, ",", "")
server.Hostname = data.Hostname
server.ISP = data.Provider
server.Owned = data.Owned
server.IPs = []net.IP{ipv4}
server.IPsV6 = []net.IP{ipv6}
server.WgPubKey = data.PubKey
hts[data.Hostname] = server

View File

@@ -29,6 +29,7 @@ type groupData struct {
City string `json:"city"`
Nodes []serverData `json:"nodes"`
OvpnX509 string `json:"ovpn_x509"`
WgPubKey string `json:"wg_pubkey"`
}
type serverData struct {

View File

@@ -9,10 +9,14 @@ import (
"net"
"net/http"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
var ErrNotEnoughServers = errors.New("not enough servers found")
var (
ErrNotEnoughServers = errors.New("not enough servers found")
ErrNoWireguardKey = errors.New("no wireguard public key found")
)
func GetServers(ctx context.Context, client *http.Client, minServers int) (
servers []models.WindscribeServer, err error) {
@@ -26,19 +30,17 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
for _, group := range regionData.Groups {
city := group.City
x5090Name := group.OvpnX509
wgPubKey := group.WgPubKey
for _, node := range group.Nodes {
const maxIPsPerNode = 3
ips := make([]net.IP, 0, maxIPsPerNode)
ips := make([]net.IP, 0, 2) // nolint:gomnd
if node.IP != nil {
ips = append(ips, node.IP)
}
if node.IP2 != nil {
ips = append(ips, node.IP2)
}
// if node.IP3 != nil { // Wireguard + Stealth
// ips = append(ips, node.IP3)
// }
server := models.WindscribeServer{
VPN: constants.OpenVPN,
Region: region,
City: city,
Hostname: node.Hostname,
@@ -46,6 +48,18 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
IPs: ips,
}
servers = append(servers, server)
if node.IP3 == nil { // Wireguard + Stealth
continue
} else if wgPubKey == "" {
return nil, fmt.Errorf("%w: for node %s", ErrNoWireguardKey, node.Hostname)
}
server.VPN = constants.Wireguard
server.OvpnX509 = ""
server.WgPubKey = wgPubKey
server.IPs = []net.IP{node.IP3}
servers = append(servers, server)
}
}
}

View File

@@ -10,6 +10,9 @@ func sortServers(servers []models.WindscribeServer) {
sort.Slice(servers, func(i, j int) bool {
if servers[i].Region == servers[j].Region {
if servers[i].City == servers[j].City {
if servers[i].Hostname == servers[j].Hostname {
return servers[i].VPN < servers[j].VPN
}
return servers[i].Hostname < servers[j].Hostname
}
return servers[i].City < servers[j].City

View File

@@ -10,6 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip"
@@ -37,6 +38,7 @@ type Loop struct {
versionInfo bool
// Configurators
openvpnConf openvpn.Interface
netLinker netlink.NetLinker
fw firewallConfigurer
routing routing.VPNGetter
portForward portforward.StartStopper
@@ -67,7 +69,7 @@ const (
func NewLoop(vpnSettings configuration.VPN,
allServers models.AllServers, openvpnConf openvpn.Interface,
fw firewallConfigurer, routing routing.VPNGetter,
netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter,
portForward portforward.StartStopper, starter command.Starter,
publicip publicip.Looper, dnsLooper dns.Looper,
logger logging.Logger, client *http.Client,
@@ -86,6 +88,7 @@ func NewLoop(vpnSettings configuration.VPN,
buildInfo: buildInfo,
versionInfo: versionInfo,
openvpnConf: openvpnConf,
netLinker: netLinker,
fw: fw,
routing: routing,
portForward: portForward,

View File

@@ -30,8 +30,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.starter, l.logger)
var vpnRunner vpnRunner
var serverName, vpnInterface string
var err error
if settings.Type == constants.OpenVPN {
vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.starter, l.logger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.logger)
}
if err != nil {
l.crashed(ctx, err)
continue
@@ -40,7 +49,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
portForwarding: settings.Provider.PortForwarding.Enabled,
serverName: serverName,
portForwarder: providerConf,
vpnIntf: settings.OpenVPN.Interface,
vpnIntf: vpnInterface,
}
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())

View File

@@ -17,13 +17,6 @@ type tunnelUpData struct {
}
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
vpnDestination, err := l.routing.VPNDestinationIP()
if err != nil {
l.logger.Warn(err.Error())
} else {
l.logger.Info("VPN routing IP address: " + vpnDestination.String())
}
if l.dnsLooper.GetSettings().Enabled {
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
}
@@ -40,7 +33,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
}
}
err = l.startPortForwarding(ctx, data)
err := l.startPortForwarding(ctx, data)
if err != nil {
l.logger.Error(err.Error())
}

45
internal/vpn/wireguard.go Normal file
View File

@@ -0,0 +1,45 @@
package vpn
import (
"context"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/wireguard"
)
var (
errGetServer = errors.New("failed finding a VPN server")
errCreateWireguard = errors.New("failed creating Wireguard")
)
// setupWireguard sets Wireguard up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
fw firewall.VPNConnectionSetter, providerConf provider.Provider,
settings configuration.VPN, logger wireguard.Logger) (
wireguarder wireguard.Wireguarder, serverName string, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errGetServer, err)
}
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard)
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errCreateWireguard, err)
}
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
}
return wireguarder, connection.Hostname, nil
}

View File

@@ -0,0 +1,25 @@
package wireguard
import (
"fmt"
"net"
"github.com/vishvananda/netlink"
)
func (w *Wireguard) addAddresses(link netlink.Link,
addresses []*net.IPNet) (err error) {
for _, ipNet := range addresses {
address := &netlink.Addr{
IPNet: ipNet,
}
err = w.netlink.AddrAdd(link, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Attrs().Name)
}
}
return nil
}

View File

@@ -0,0 +1,94 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)}
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
addrs []*net.IPNet
expectedAddrs []*netlink.Addr
addrAddErrs []error
err error
}{
"success": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
},
addrAddErrs: []error{nil, nil},
},
"first add error": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{
{IPNet: ipNetOne},
},
addrAddErrs: []error{errDummy},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
},
"second add error": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
},
addrAddErrs: []error{nil, errDummy},
err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs))
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
for i := range testCase.expectedAddrs {
netLinker.EXPECT().
AddrAdd(testCase.link, testCase.expectedAddrs[i]).
Return(testCase.addrAddErrs[i])
}
err := wg.addAddresses(testCase.link, testCase.addrs)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,59 @@
package wireguard
import "sort"
type closer struct {
operation string
step step
close func() error
closed bool
}
type closers []closer
func (c *closers) add(operation string, step step,
closeFunc func() error) {
closer := closer{
operation: operation,
step: step,
close: closeFunc,
}
*c = append(*c, closer)
}
func (c *closers) cleanup(logger Logger) {
closers := *c
sort.Slice(closers, func(i, j int) bool {
return closers[i].step < closers[j].step
})
for i, closer := range closers {
if closer.closed {
continue
} else {
closers[i].closed = true
}
logger.Debug(closer.operation + "...")
err := closer.close()
if err != nil {
logger.Error("failed " + closer.operation + ": " + err.Error())
}
}
}
type step int
const (
// stepOne closes the wireguard controller client,
// and removes the IP rule.
stepOne step = iota
// stepTwo closes the UAPI listener.
stepTwo
// stepThree closes the UAPI file.
stepThree
// stepFour closes the Wireguard device.
stepFour
// stepFive closes the bind connection and the TUN device file.
stepFive
)

View File

@@ -0,0 +1,57 @@
package wireguard
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_closers(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var ACloseCalled, BCloseCalled, CCloseCalled bool
var (
AErr error
BErr = errors.New("B failed")
CErr = errors.New("C failed")
)
var closers closers
closers.add("closing A", stepFive, func() error {
ACloseCalled = true
return AErr
})
closers.add("closing B", stepThree, func() error {
BCloseCalled = true
return BErr
})
closers.add("closing C", stepTwo, func() error {
CCloseCalled = true
return CErr
})
logger := NewMockLogger(ctrl)
prevCall := logger.EXPECT().Debug("closing C...")
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
logger.EXPECT().Debug("closing A...").After(prevCall)
closers.cleanup(logger)
closers.cleanup(logger) // run twice should not close already closed
for _, closer := range closers {
assert.True(t, closer.closed)
}
assert.True(t, ACloseCalled)
assert.True(t, BCloseCalled)
assert.True(t, CCloseCalled)
}

View File

@@ -0,0 +1,86 @@
package wireguard
import (
"errors"
"fmt"
"net"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
errMakeConfig = errors.New("cannot make device configuration")
errConfigureDevice = errors.New("cannot configure device")
)
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
deviceConfig, err := makeDeviceConfig(settings)
if err != nil {
return fmt.Errorf("%w: %s", errMakeConfig, err)
}
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
if err != nil {
return fmt.Errorf("%w: %s", errConfigureDevice, err)
}
return nil
}
func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
privateKey, err := wgtypes.ParseKey(settings.PrivateKey)
if err != nil {
return config, ErrPrivateKeyInvalid
}
publicKey, err := wgtypes.ParseKey(settings.PublicKey)
if err != nil {
return config, fmt.Errorf("%w: %s", ErrPublicKeyInvalid, settings.PublicKey)
}
var preSharedKey *wgtypes.Key
if settings.PreSharedKey != "" {
preSharedKeyValue, err := wgtypes.ParseKey(settings.PreSharedKey)
if err != nil {
return config, ErrPreSharedKeyInvalid
}
preSharedKey = &preSharedKeyValue
}
firewallMark := settings.FirewallMark
config = wgtypes.Config{
PrivateKey: &privateKey,
ReplacePeers: true,
FirewallMark: &firewallMark,
Peers: []wgtypes.PeerConfig{
{
PublicKey: publicKey,
PresharedKey: preSharedKey,
AllowedIPs: []net.IPNet{
*allIPv4(),
*allIPv6(),
},
ReplaceAllowedIPs: true,
Endpoint: settings.Endpoint,
},
},
}
return config, nil
}
func allIPv4() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
}
}
func allIPv6() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
}
}

View File

@@ -0,0 +1,126 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func Test_makeDeviceConfig(t *testing.T) {
t.Parallel()
const (
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
validKey3 = "gFIW0lTmBYEucynoIg+XmeWckDUXTcC4Po5ijR5G+HM="
)
parseKey := func(t *testing.T, s string) *wgtypes.Key {
t.Helper()
key, err := wgtypes.ParseKey(s)
require.NoError(t, err)
return &key
}
intPtr := func(n int) *int { return &n }
testCases := map[string]struct {
settings Settings
config wgtypes.Config
err error
}{
"bad private key": {
settings: Settings{
PrivateKey: "bad key",
},
err: ErrPrivateKeyInvalid,
},
"bad public key": {
settings: Settings{
PrivateKey: validKey1,
PublicKey: "bad key",
},
err: errors.New("cannot parse public key: bad key"),
},
"bad pre-shared key": {
settings: Settings{
PrivateKey: validKey1,
PublicKey: validKey2,
PreSharedKey: "bad key",
},
err: errors.New("cannot parse pre-shared key"),
},
"valid settings": {
settings: Settings{
PrivateKey: validKey1,
PublicKey: validKey2,
PreSharedKey: validKey3,
FirewallMark: 9876,
Endpoint: &net.UDPAddr{
IP: net.IPv4(99, 99, 99, 99),
Port: 51820,
},
},
config: wgtypes.Config{
PrivateKey: parseKey(t, validKey1),
ReplacePeers: true,
FirewallMark: intPtr(9876),
Peers: []wgtypes.PeerConfig{
{
PublicKey: *parseKey(t, validKey2),
PresharedKey: parseKey(t, validKey3),
AllowedIPs: []net.IPNet{
{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
},
{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
},
},
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
IP: net.IPv4(99, 99, 99, 99),
Port: 51820,
},
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
config, err := makeDeviceConfig(testCase.settings)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.config, config)
})
}
}
func Test_allIPv4(t *testing.T) {
t.Parallel()
ipNet := allIPv4()
assert.Equal(t, "0.0.0.0/0", ipNet.String())
}
func Test_allIPv6(t *testing.T) {
t.Parallel()
ipNet := allIPv6()
assert.Equal(t, "::/0", ipNet.String())
}

View File

@@ -0,0 +1,30 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
var _ Wireguarder = (*Wireguard)(nil)
type Wireguarder interface {
Runner
Runner
}
type Wireguard struct {
logger Logger
settings Settings
netlink netlink.NetLinker
}
func New(settings Settings, netlink NetLinker,
logger Logger) (w *Wireguard, err error) {
settings.SetDefaults()
if err := settings.Check(); err != nil {
return nil, err
}
return &Wireguard{
logger: logger,
settings: settings,
netlink: netlink,
}, nil
}

View File

@@ -0,0 +1,80 @@
package wireguard
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_New(t *testing.T) {
t.Parallel()
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
logger := NewMockLogger(nil)
netLinker := NewMockNetLinker(nil)
testCases := map[string]struct {
settings Settings
wireguard *Wireguard
err error
}{
"bad settings": {
settings: Settings{
PrivateKey: "",
},
err: ErrPrivateKeyMissing,
},
"minimal valid settings": {
settings: Settings{
PrivateKey: validKeyString,
PublicKey: validKeyString,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
Addresses: []*net.IPNet{{
IP: net.IPv4(5, 6, 7, 8),
Mask: net.IPv4Mask(255, 255, 255, 255)},
},
FirewallMark: 100,
},
wireguard: &Wireguard{
logger: logger,
netlink: netLinker,
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKeyString,
PublicKey: validKeyString,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{
IP: net.IPv4(5, 6, 7, 8),
Mask: net.IPv4Mask(255, 255, 255, 255)},
},
FirewallMark: 100,
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
wireguard, err := New(testCase.settings, netLinker, logger)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.wireguard, wireguard)
})
}
}

26
internal/wireguard/log.go Normal file
View File

@@ -0,0 +1,26 @@
package wireguard
import (
"fmt"
"golang.zx2c4.com/wireguard/device"
)
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
return &device.Logger{
Verbosef: func(format string, args ...interface{}) {
logger.Debug(fmt.Sprintf(format, args...))
},
Errorf: func(format string, args ...interface{}) {
logger.Error(fmt.Sprintf(format, args...))
},
}
}

View File

@@ -0,0 +1,70 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: Logger)
// Package wireguard is a generated GoMock package.
package wireguard
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}

View File

@@ -0,0 +1,23 @@
package wireguard
import (
"testing"
"github.com/golang/mock/gomock"
)
func Test_makeDeviceLogger(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
deviceLogger := makeDeviceLogger(logger)
logger.EXPECT().Debug("test 1")
deviceLogger.Verbosef("test %d", 1)
logger.EXPECT().Error("test 2")
deviceLogger.Errorf("test %d", 2)
}

View File

@@ -0,0 +1,113 @@
// +build netlink
package wireguard
import (
"fmt"
"math/rand"
"net"
"testing"
inetlink "github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
netlinker := inetlink.New()
wg := &Wireguard{
netlink: netlinker,
}
intfName := "test_" + fmt.Sprint(rand.Intn(10000)) //nolint:gosec
// Add link
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = intfName
link := &netlink.Bridge{
LinkAttrs: linkAttrs,
}
err := netlink.LinkAdd(link)
require.NoError(t, err)
defer func() {
err = netlink.LinkDel(link)
assert.NoError(t, err)
}()
addresses := []*net.IPNet{
{IP: net.IP{1, 2, 3, 4}, Mask: net.IPv4Mask(255, 255, 255, 255)},
{IP: net.IP{5, 6, 7, 8}, Mask: net.IPv4Mask(255, 255, 255, 255)},
}
// Success
err = wg.addAddresses(link, addresses)
require.NoError(t, err)
netlinkAddresses, err := netlink.AddrList(link, netlink.FAMILY_ALL)
require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
ipNet := netlinkAddress.IPNet
assert.Equal(t, addresses[i], ipNet)
}
// Existing address cannot be added
err = wg.addAddresses(link, addresses)
require.Error(t, err)
assert.Equal(t, "file exists: when adding address 1.2.3.4/32 to link test_8081", err.Error())
}
func Test_netlink_Wireguard_addRule(t *testing.T) {
t.Parallel()
netlinker := inetlink.New()
wg := &Wireguard{
netlink: netlinker,
}
rulePriority := 10000
const firewallMark = 999
cleanup, err := wg.addRule(rulePriority, firewallMark)
require.NoError(t, err)
defer func() {
err := cleanup()
assert.NoError(t, err)
}()
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
require.NoError(t, err)
var rule netlink.Rule
var ruleFound bool
for _, rule = range rules {
if rule.Mark == firewallMark {
ruleFound = true
break
}
}
require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: 4294967295,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
}
assert.Equal(t, expectedRule, rule)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority, firewallMark)
if nilCleanup != nil {
_ = nilCleanup() // in case it succeeds
}
require.Error(t, err)
assert.Equal(t, "file exists: when adding rule: ip rule 10000: from <nil> table 999", err.Error())
}

View File

@@ -0,0 +1,12 @@
package wireguard
import "github.com/vishvananda/netlink"
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrAdd(link netlink.Link, addr *netlink.Addr) error
RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
}

View File

@@ -0,0 +1,91 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: NetLinker)
// Package wireguard is a generated GoMock package.
package wireguard
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
type MockNetLinker struct {
ctrl *gomock.Controller
recorder *MockNetLinkerMockRecorder
}
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
type MockNetLinkerMockRecorder struct {
mock *MockNetLinker
}
// NewMockNetLinker creates a new mock instance.
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
mock := &MockNetLinker{ctrl: ctrl}
mock.recorder = &MockNetLinkerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
return m.recorder
}
// AddrAdd mocks base method.
func (m *MockNetLinker) AddrAdd(arg0 netlink.Link, arg1 *netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrAdd", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// AddrAdd indicates an expected call of AddrAdd.
func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1)
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RouteAdd indicates an expected call of RouteAdd.
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RuleAdd indicates an expected call of RuleAdd.
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RuleDel indicates an expected call of RuleDel.
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
}

View File

@@ -0,0 +1,26 @@
package wireguard
import (
"fmt"
"net"
"github.com/vishvananda/netlink"
)
// TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
firewallMark int) (err error) {
route := &netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: dst,
Table: firewallMark,
}
err = w.netlink.RouteAdd(route)
if err != nil {
return fmt.Errorf("%w: when adding route: %s", err, route)
}
return err
}

View File

@@ -0,0 +1,85 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_Wireguard_addRoute(t *testing.T) {
t.Parallel()
const linkIndex = 88
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
linkAttrs.Index = linkIndex
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
const firewallMark = 51820
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst *net.IPNet
expectedRoute *netlink.Route
routeAddErr error
err error
}{
"success": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Table: firewallMark,
},
},
"route add error": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Table: firewallMark,
},
routeAddErr: errDummy,
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820}"), //nolint:lll
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,28 @@
package wireguard
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *Wireguard) addRule(rulePriority, firewallMark int) (
cleanup func() error, err error) {
rule := netlink.NewRule()
rule.Invert = true
rule.Priority = rulePriority
rule.Mark = firewallMark
rule.Table = firewallMark
if err := w.netlink.RuleAdd(rule); err != nil {
return nil, fmt.Errorf("%w: when adding rule: %s", err, rule)
}
cleanup = func() error {
err := w.netlink.RuleDel(rule)
if err != nil {
return fmt.Errorf("%w: when deleting rule: %s", err, rule)
}
return nil
}
return cleanup, nil
}

View File

@@ -0,0 +1,106 @@
package wireguard
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_Wireguard_addRule(t *testing.T) {
t.Parallel()
const rulePriority = 987
const firewallMark = 456
errDummy := errors.New("dummy")
testCases := map[string]struct {
expectedRule *netlink.Rule
ruleAddErr error
err error
ruleDelErr error
cleanupErr error
}{
"success": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
},
},
"rule add error": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
},
ruleAddErr: errDummy,
err: errors.New("dummy: when adding rule: ip rule 987: from <nil> table 456"),
},
"rule delete error": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from <nil> table 456"),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
Return(testCase.ruleAddErr)
cleanup, err := wg.addRule(rulePriority, firewallMark)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
return
}
require.NoError(t, err)
netLinker.EXPECT().RuleDel(testCase.expectedRule).
Return(testCase.ruleDelErr)
err = cleanup()
if testCase.cleanupErr != nil {
require.Error(t, err)
assert.Equal(t, testCase.cleanupErr.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}

165
internal/wireguard/run.go Normal file
View File

@@ -0,0 +1,165 @@
package wireguard
import (
"context"
"errors"
"fmt"
"net"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl"
)
var (
ErrCreateTun = errors.New("cannot create TUN device")
ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device")
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
ErrWgctrlOpen = errors.New("cannot open wgctrl")
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
ErrAddAddress = errors.New("cannot add address to wireguard interface")
ErrConfigure = errors.New("cannot configure wireguard interface")
ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface")
ErrRuleAdd = errors.New("cannot add rule for interface")
ErrDeviceWaited = errors.New("device waited for")
)
type Runner interface {
Run(ctx context.Context, waitError chan<- error, ready chan<- struct{})
}
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
client, err := wgctrl.New()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
return
}
var closers closers
closers.add("closing controller client", stepOne, client.Close)
defer closers.cleanup(w.logger)
tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err)
return
}
closers.add("closing TUN device", stepFive, tun.Close)
tunName, err := tun.Name()
if err != nil {
waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return
} else if tunName != w.settings.InterfaceName {
waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, w.settings.InterfaceName, tunName)
return
}
link, err := netlink.LinkByName(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err)
return
}
bind := conn.NewDefaultBind()
closers.add("closing bind", stepFive, bind.Close)
deviceLogger := makeDeviceLogger(w.logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepFour, func() error {
device.Close()
return nil
})
uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err)
return
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
err = w.addAddresses(link, w.settings.Addresses)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
return
}
err = configureDevice(client, w.settings)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
return
}
if err := netlink.LinkSetUp(link); err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
}
ruleCleanup, err := w.addRule(
w.settings.RulePriority, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRuleAdd, err)
return
}
closers.add("removing rule", stepOne, ruleCleanup)
w.logger.Info("Wireguard is up")
ready <- struct{}{}
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = ErrDeviceWaited
}
closers.cleanup(w.logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
waitError <- err
}
func acceptAndHandle(uapi net.Listener, device *device.Device,
uapiAcceptErrorCh chan<- error) {
for { // stopped by uapiFile.Close()
conn, err := uapi.Accept()
if err != nil {
uapiAcceptErrorCh <- err
return
}
go device.IpcHandle(conn)
}
}

View File

@@ -0,0 +1,212 @@
package wireguard
import (
"errors"
"fmt"
"net"
"regexp"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Settings struct {
// Interface name for the Wireguard interface.
// It defaults to wg0 if unset.
InterfaceName string
// Private key in base 64 format
PrivateKey string
// Public key in base 64 format
PublicKey string
// Pre shared key in base 64 format
PreSharedKey string
// Wireguard server endpoint to connect to.
Endpoint *net.UDPAddr
// Addresses assigned to the client.
Addresses []*net.IPNet
// FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0.
FirewallMark int
// RulePriority is the priority for the rule created with the
// FirewallMark.
RulePriority int
}
func (s *Settings) SetDefaults() {
if s.InterfaceName == "" {
const defaultInterfaceName = "wg0"
s.InterfaceName = defaultInterfaceName
}
if s.Endpoint != nil && s.Endpoint.Port == 0 {
const defaultPort = 51820
s.Endpoint.Port = defaultPort
}
if s.FirewallMark == 0 {
const defaultFirewallMark = 51820
s.FirewallMark = defaultFirewallMark
}
}
var (
ErrInterfaceNameInvalid = errors.New("invalid interface name")
ErrPrivateKeyMissing = errors.New("private key is missing")
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointMissing = errors.New("endpoint is missing")
ErrEndpointIPMissing = errors.New("endpoint IP is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNil = errors.New("interface address is nil")
ErrAddressIPMissing = errors.New("interface address IP is missing")
ErrAddressMaskMissing = errors.New("interface address mask is missing")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
)
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
func (s *Settings) Check() (err error) {
if !interfaceNameRegexp.MatchString(s.InterfaceName) {
return fmt.Errorf("%w: %s", ErrInterfaceNameInvalid, s.InterfaceName)
}
if s.PrivateKey == "" {
return ErrPrivateKeyMissing
} else if _, err := wgtypes.ParseKey(s.PrivateKey); err != nil {
return ErrPrivateKeyInvalid
}
if s.PublicKey == "" {
return ErrPublicKeyMissing
} else if _, err := wgtypes.ParseKey(s.PublicKey); err != nil {
return fmt.Errorf("%w: %s", ErrPublicKeyInvalid, s.PublicKey)
}
if s.PreSharedKey != "" {
if _, err := wgtypes.ParseKey(s.PreSharedKey); err != nil {
return ErrPreSharedKeyInvalid
}
}
switch {
case s.Endpoint == nil:
return ErrEndpointMissing
case s.Endpoint.IP == nil:
return ErrEndpointIPMissing
case s.Endpoint.Port == 0:
return ErrEndpointPortMissing
}
if len(s.Addresses) == 0 {
return ErrAddressMissing
}
for i, addr := range s.Addresses {
switch {
case addr == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressNil, i+1, len(s.Addresses))
case addr.IP == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressIPMissing, i+1, len(s.Addresses))
case addr.Mask == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressMaskMissing, i+1, len(s.Addresses))
}
}
if s.FirewallMark == 0 {
return ErrFirewallMarkMissing
}
return nil
}
func (s Settings) String() string {
lines := s.ToLines(ToLinesSettings{})
return strings.Join(lines, "\n")
}
type ToLinesSettings struct {
// Indent defaults to 4 spaces " ".
Indent *string
// FieldPrefix defaults to "├── ".
FieldPrefix *string
// LastFieldPrefix defaults to "└── ".
LastFieldPrefix *string
}
func (settings *ToLinesSettings) setDefaults() {
toStringPtr := func(s string) *string { return &s }
if settings.Indent == nil {
settings.Indent = toStringPtr(" ")
}
if settings.FieldPrefix == nil {
settings.FieldPrefix = toStringPtr("├── ")
}
if settings.LastFieldPrefix == nil {
settings.LastFieldPrefix = toStringPtr("└── ")
}
}
// ToLines serializes the settings to a slice of strings for display.
func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
settings.setDefaults()
indent := *settings.Indent
fieldPrefix := *settings.FieldPrefix
lastFieldPrefix := *settings.LastFieldPrefix
lines = append(lines, fieldPrefix+"Interface name: "+s.InterfaceName)
const (
set = "set"
notSet = "not set"
)
isSet := notSet
if s.PrivateKey != "" {
isSet = set
}
lines = append(lines, fieldPrefix+"Private key: "+isSet)
if s.PublicKey != "" {
lines = append(lines, fieldPrefix+"PublicKey: "+s.PublicKey)
}
isSet = notSet
if s.PreSharedKey != "" {
isSet = set
}
lines = append(lines, fieldPrefix+"Pre shared key: "+isSet)
endpointStr := notSet
if s.Endpoint != nil {
endpointStr = s.Endpoint.String()
}
lines = append(lines, fieldPrefix+"Endpoint: "+endpointStr)
if s.FirewallMark != 0 {
lines = append(lines, fieldPrefix+"Firewall mark: "+fmt.Sprint(s.FirewallMark))
}
if s.RulePriority != 0 {
lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority))
}
if len(s.Addresses) == 0 {
lines = append(lines, lastFieldPrefix+"Addresses: "+notSet)
} else {
lines = append(lines, lastFieldPrefix+"Addresses:")
for i, address := range s.Addresses {
prefix := fieldPrefix
if i == len(s.Addresses)-1 {
prefix = lastFieldPrefix
}
lines = append(lines, indent+prefix+address.String())
}
}
return lines
}

View File

@@ -0,0 +1,377 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Settings_SetDefaults(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
original Settings
expected Settings
}{
"empty settings": {
expected: Settings{
InterfaceName: "wg0",
FirewallMark: 51820,
},
},
"default endpoint port": {
original: Settings{
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
},
expected: Settings{
InterfaceName: "wg0",
FirewallMark: 51820,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
},
},
"not empty settings": {
original: Settings{
InterfaceName: "wg1",
FirewallMark: 999,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 9999,
},
},
expected: Settings{
InterfaceName: "wg1",
FirewallMark: 999,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 9999,
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
testCase.original.SetDefaults()
assert.Equal(t, testCase.expected, testCase.original)
})
}
}
func Test_Settings_Check(t *testing.T) {
t.Parallel()
const (
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
)
testCases := map[string]struct {
settings Settings
err error
}{
"empty settings": {
err: errors.New("invalid interface name: "),
},
"bad interface name": {
settings: Settings{
InterfaceName: "$H1T",
},
err: errors.New("invalid interface name: $H1T"),
},
"empty private key": {
settings: Settings{
InterfaceName: "wg0",
},
err: ErrPrivateKeyMissing,
},
"bad private key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: "bad key",
},
err: ErrPrivateKeyInvalid,
},
"empty public key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
},
err: ErrPublicKeyMissing,
},
"bad public key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: "bad key",
},
err: errors.New("cannot parse public key: bad key"),
},
"bad preshared key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
PreSharedKey: "bad key",
},
err: errors.New("cannot parse pre-shared key"),
},
"empty endpoint": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
},
err: ErrEndpointMissing,
},
"nil endpoint IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{},
},
err: ErrEndpointIPMissing,
},
"nil endpoint port": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
},
err: ErrEndpointPortMissing,
},
"no address": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
},
err: ErrAddressMissing,
},
"nil address": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{nil},
},
err: errors.New("interface address is nil: for address 1 of 1"),
},
"nil address IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{}},
},
err: errors.New("interface address IP is missing: for address 1 of 1"),
},
"nil address mask": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4)}},
},
err: errors.New("interface address mask is missing: for address 1 of 1"),
},
"zero firewall mark": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
},
err: ErrFirewallMarkMissing,
},
"all valid": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
FirewallMark: 999,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
err := testCase.settings.Check()
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func toStringPtr(s string) *string { return &s }
func Test_ToLinesSettings_setDefaults(t *testing.T) {
t.Parallel()
settings := ToLinesSettings{
Indent: toStringPtr("indent"),
}
someFunc := func(settings ToLinesSettings) {
settings.setDefaults()
expectedSettings := ToLinesSettings{
Indent: toStringPtr("indent"),
FieldPrefix: toStringPtr("├── "),
LastFieldPrefix: toStringPtr("└── "),
}
assert.Equal(t, expectedSettings, settings)
}
someFunc(settings)
untouchedSettings := ToLinesSettings{
Indent: toStringPtr("indent"),
}
assert.Equal(t, untouchedSettings, settings)
}
func Test_Settings_String(t *testing.T) {
t.Parallel()
settings := Settings{
InterfaceName: "wg0",
}
const expected = `├── Interface name: wg0
├── Private key: not set
├── Pre shared key: not set
├── Endpoint: not set
└── Addresses: not set`
s := settings.String()
assert.Equal(t, expected, s)
}
func Test_Settings_Lines(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
lineSettings ToLinesSettings
lines []string
}{
"empty settings": {
lines: []string{
"├── Interface name: ",
"├── Private key: not set",
"├── Pre shared key: not set",
"├── Endpoint: not set",
"└── Addresses: not set",
},
},
"settings all set": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: "private key",
PublicKey: "public key",
PreSharedKey: "pre-shared key",
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
FirewallMark: 999,
RulePriority: 888,
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
},
},
lines: []string{
"├── Interface name: wg0",
"├── Private key: set",
"├── PublicKey: public key",
"├── Pre shared key: set",
"├── Endpoint: 1.2.3.4:51820",
"├── Firewall mark: 999",
"├── Rule priority: 888",
"└── Addresses:",
" ├── 1.1.1.1/24",
" └── 2.2.2.2/32",
},
},
"custom line settings": {
lineSettings: ToLinesSettings{
Indent: toStringPtr(" "),
FieldPrefix: toStringPtr("- "),
LastFieldPrefix: toStringPtr("* "),
},
settings: Settings{
InterfaceName: "wg0",
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
},
},
lines: []string{
"- Interface name: wg0",
"- Private key: not set",
"- Pre shared key: not set",
"- Endpoint: not set",
"* Addresses:",
" - 1.1.1.1/24",
" * 2.2.2.2/32",
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
lines := testCase.settings.ToLines(testCase.lineSettings)
assert.Equal(t, testCase.lines, lines)
})
}
}

View File

@@ -11,6 +11,7 @@
- Filter servers by protocol for all
- Multiple IPs addresses support for all proviedrs
- Use `internal/netlink` in firewall and routing packages
## Code