Wireguard support for Mullvad and Windscribe (#565)

- `internal/wireguard` client package with unit tests
- Implementation works with kernel space or user space if unavailable
- `WIREGUARD_PRIVATE_KEY`
- `WIREGUARD_ADDRESS`
- `WIREGUARD_PRESHARED_KEY`
- `WIREGUARD_PORT`
- `internal/netlink` package used by `internal/wireguard`
This commit is contained in:
Quentin McGaw
2021-08-22 14:58:39 -07:00
committed by GitHub
parent 0bfd58a3f5
commit 614eb10d67
70 changed files with 13595 additions and 148 deletions

View File

@@ -49,7 +49,7 @@ func (settings *OpenVPNSelection) readMullvad(env params.Env) (err error) {
return err
}
settings.CustomPort, err = readCustomPort(env, settings.TCP,
settings.CustomPort, err = readOpenVPNCustomPort(env, settings.TCP,
[]uint16{80, 443, 1401}, []uint16{53, 1194, 1195, 1196, 1197, 1300, 1301, 1302, 1303, 1400})
if err != nil {
return err

View File

@@ -43,7 +43,7 @@ var (
)
func (settings *Provider) read(r reader, vpnType string) error {
err := settings.readVPNServiceProvider(r)
err := settings.readVPNServiceProvider(r, vpnType)
if err != nil {
return err
}
@@ -94,11 +94,17 @@ func (settings *Provider) read(r reader, vpnType string) error {
return nil
}
func (settings *Provider) readVPNServiceProvider(r reader) (err error) {
allowedVPNServiceProviders := []string{
"cyberghost", "fastestvpn", "hidemyass", "ipvanish", "ivpn", "mullvad", "nordvpn",
"privado", "pia", "private internet access", "privatevpn", "protonvpn",
"purevpn", "surfshark", "torguard", constants.VPNUnlimited, "vyprvpn", "windscribe"}
func (settings *Provider) readVPNServiceProvider(r reader, vpnType string) (err error) {
var allowedVPNServiceProviders []string
switch vpnType {
case constants.OpenVPN:
allowedVPNServiceProviders = []string{
"cyberghost", "fastestvpn", "hidemyass", "ipvanish", "ivpn", "mullvad", "nordvpn",
"privado", "pia", "private internet access", "privatevpn", "protonvpn",
"purevpn", "surfshark", "torguard", constants.VPNUnlimited, "vyprvpn", "windscribe"}
case constants.Wireguard:
allowedVPNServiceProviders = []string{constants.Mullvad, constants.Windscribe}
}
vpnsp, err := r.env.Inside("VPNSP", allowedVPNServiceProviders,
params.Default("private internet access"))
@@ -132,7 +138,7 @@ func readTargetIP(env params.Env) (targetIP net.IP, err error) {
return targetIP, nil
}
func readCustomPort(env params.Env, tcp bool,
func readOpenVPNCustomPort(env params.Env, tcp bool,
allowedTCP, allowedUDP []uint16) (port uint16, err error) {
port, err = readPortOrZero(env, "PORT")
if err != nil {
@@ -147,12 +153,42 @@ func readCustomPort(env params.Env, tcp bool,
return port, nil
}
}
return 0, fmt.Errorf("environment variable PORT: %w: port %d for TCP protocol", ErrInvalidPort, port)
return 0, fmt.Errorf(
"environment variable PORT: %w: port %d for TCP protocol, can only be one of %s",
ErrInvalidPort, port, portsToString(allowedTCP))
}
for i := range allowedUDP {
if allowedUDP[i] == port {
return port, nil
}
}
return 0, fmt.Errorf("environment variable PORT: %w: port %d for UDP protocol", ErrInvalidPort, port)
return 0, fmt.Errorf(
"environment variable PORT: %w: port %d for UDP protocol, can only be one of %s",
ErrInvalidPort, port, portsToString(allowedUDP))
}
func readWireguardCustomPort(env params.Env, allowed []uint16) (port uint16, err error) {
port, err = readPortOrZero(env, "WIREGUARD_PORT")
if err != nil {
return 0, fmt.Errorf("environment variable WIREGUARD_PORT: %w", err)
} else if port == 0 {
return 0, nil
}
for i := range allowed {
if allowed[i] == port {
return port, nil
}
}
return 0, fmt.Errorf(
"environment variable WIREGUARD_PORT: %w: port %d, can only be one of %s",
ErrInvalidPort, port, portsToString(allowed))
}
func portsToString(ports []uint16) string {
slice := make([]string, len(ports))
for i := range ports {
slice[i] = fmt.Sprint(ports[i])
}
return strings.Join(slice, ", ")
}

View File

@@ -24,6 +24,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Cyberghost,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Groups: []string{"group"},
Regions: []string{"a", "El country"},
},
@@ -40,6 +41,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Fastestvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Hostnames: []string{"a", "b"},
Countries: []string{"c", "d"},
},
@@ -56,6 +58,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.HideMyAss,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -74,6 +77,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Ipvanish,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -92,6 +96,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Ivpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -110,6 +115,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Mullvad,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
ISPs: []string{"e", "f"},
@@ -132,6 +138,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Nordvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
Numbers: []uint16{1, 2},
},
@@ -148,6 +155,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Privado,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Hostnames: []string{"a", "b"},
},
},
@@ -162,6 +170,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Privatevpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Hostnames: []string{"a", "b"},
Countries: []string{"c", "d"},
Cities: []string{"e", "f"},
@@ -180,6 +189,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Protonvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Regions: []string{"c", "d"},
Cities: []string{"e", "f"},
@@ -202,6 +212,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.PrivateInternetAccess,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
OpenVPN: OpenVPNSelection{
CustomPort: 1,
@@ -226,6 +237,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Purevpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
Countries: []string{"c", "d"},
Cities: []string{"e", "f"},
@@ -244,6 +256,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Surfshark,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
},
},
@@ -258,6 +271,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Torguard,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e"},
@@ -276,6 +290,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.VPNUnlimited,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Countries: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},
@@ -298,6 +313,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Vyprvpn,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
},
},
@@ -312,6 +328,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Windscribe,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
Regions: []string{"a", "b"},
Cities: []string{"c", "d"},
Hostnames: []string{"e", "f"},

View File

@@ -109,12 +109,12 @@ func readIP(env params.Env, key string) (ip net.IP, err error) {
}
func readPortOrZero(env params.Env, key string) (port uint16, err error) {
s, err := env.Get(key)
s, err := env.Get(key, params.Default("0"))
if err != nil {
return 0, err
}
if s == "" || s == "0" {
if s == "0" {
return 0, nil
}

View File

@@ -4,12 +4,13 @@ import (
"fmt"
"net"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/params"
)
type ServerSelection struct { //nolint:maligned
// Common
VPN string `json:"vpn"`
VPN string `json:"vpn"` // note: this is required
TargetIP net.IP `json:"target_ip,omitempty"`
// TODO comments
// Cyberghost, PIA, Protonvpn, Surfshark, Windscribe, Vyprvpn, NordVPN
@@ -39,7 +40,8 @@ type ServerSelection struct { //nolint:maligned
// VPNUnlimited
StreamOnly bool `json:"stream_only"`
OpenVPN OpenVPNSelection `json:"openvpn"`
OpenVPN OpenVPNSelection `json:"openvpn"`
Wireguard WireguardSelection `json:"wireguard"`
}
func (selection ServerSelection) toLines() (lines []string) {
@@ -91,7 +93,11 @@ func (selection ServerSelection) toLines() (lines []string) {
lines = append(lines, lastIndent+"Numbers: "+commaJoin(numbersString))
}
lines = append(lines, selection.OpenVPN.lines()...)
if selection.VPN == constants.OpenVPN {
lines = append(lines, selection.OpenVPN.lines()...)
} else { // wireguard
lines = append(lines, selection.Wireguard.lines()...)
}
return lines
}
@@ -137,6 +143,20 @@ func (settings *OpenVPNSelection) readProtocolAndPort(env params.Env) (err error
return nil
}
type WireguardSelection struct {
CustomPort uint16 `json:"custom_port"` // Mullvad
}
func (settings *WireguardSelection) lines() (lines []string) {
lines = append(lines, lastIndent+"Wireguard selection:")
if settings.CustomPort != 0 {
lines = append(lines, indent+lastIndent+"Custom port: "+fmt.Sprint(settings.CustomPort))
}
return lines
}
// PortForwarding contains settings for port forwarding.
type PortForwarding struct {
Enabled bool `json:"enabled"`

View File

@@ -20,6 +20,9 @@ func Test_Settings_lines(t *testing.T) {
Type: constants.OpenVPN,
Provider: Provider{
Name: constants.Mullvad,
ServerSelection: ServerSelection{
VPN: constants.OpenVPN,
},
},
OpenVPN: OpenVPN{
Version: constants.Openvpn25,

View File

@@ -10,9 +10,10 @@ import (
)
type VPN struct {
Type string `json:"type"`
OpenVPN OpenVPN `json:"openvpn"`
Provider Provider `json:"provider"`
Type string `json:"type"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
Provider Provider `json:"provider"`
}
func (settings *VPN) String() string {
@@ -24,7 +25,14 @@ func (settings *VPN) lines() (lines []string) {
lines = append(lines, indent+lastIndent+"Type: "+settings.Type)
for _, line := range settings.OpenVPN.lines() {
var vpnLines []string
switch settings.Type {
case constants.OpenVPN:
vpnLines = settings.OpenVPN.lines()
case constants.Wireguard:
vpnLines = settings.Wireguard.lines()
}
for _, line := range vpnLines {
lines = append(lines, indent+line)
}
@@ -36,13 +44,15 @@ func (settings *VPN) lines() (lines []string) {
}
var (
errReadProviderSettings = errors.New("cannot read provider settings")
errReadOpenVPNSettings = errors.New("cannot read OpenVPN settings")
errReadProviderSettings = errors.New("cannot read provider settings")
errReadOpenVPNSettings = errors.New("cannot read OpenVPN settings")
errReadWireguardSettings = errors.New("cannot read Wireguard settings")
)
func (settings *VPN) read(r reader) (err error) {
vpnType, err := r.env.Inside("VPN_TYPE",
[]string{constants.OpenVPN}, params.Default(constants.OpenVPN))
[]string{constants.OpenVPN, constants.Wireguard},
params.Default(constants.OpenVPN))
if err != nil {
return fmt.Errorf("environment variable VPN_TYPE: %w", err)
}
@@ -54,9 +64,17 @@ func (settings *VPN) read(r reader) (err error) {
}
}
err = settings.OpenVPN.read(r, settings.Provider.Name)
if err != nil {
return fmt.Errorf("%w: %s", errReadOpenVPNSettings, err)
switch settings.Type {
case constants.OpenVPN:
err = settings.OpenVPN.read(r, settings.Provider.Name)
if err != nil {
return fmt.Errorf("%w: %s", errReadOpenVPNSettings, err)
}
case constants.Wireguard:
err = settings.Wireguard.read(r)
if err != nil {
return fmt.Errorf("%w: %s", errReadWireguardSettings, err)
}
}
return nil

View File

@@ -30,7 +30,12 @@ func (settings *Provider) readWindscribe(r reader) (err error) {
return fmt.Errorf("environment variable SERVER_HOSTNAME: %w", err)
}
return settings.ServerSelection.OpenVPN.readWindscribe(r.env)
err = settings.ServerSelection.OpenVPN.readWindscribe(r.env)
if err != nil {
return err
}
return settings.ServerSelection.Wireguard.readWindscribe(r.env)
}
func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
@@ -39,7 +44,7 @@ func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
return err
}
settings.CustomPort, err = readCustomPort(env, settings.TCP,
settings.CustomPort, err = readOpenVPNCustomPort(env, settings.TCP,
[]uint16{21, 22, 80, 123, 143, 443, 587, 1194, 3306, 8080, 54783},
[]uint16{53, 80, 123, 443, 1194, 54783})
if err != nil {
@@ -48,3 +53,13 @@ func (settings *OpenVPNSelection) readWindscribe(env params.Env) (err error) {
return nil
}
func (settings *WireguardSelection) readWindscribe(env params.Env) (err error) {
settings.CustomPort, err = readWireguardCustomPort(env,
[]uint16{53, 80, 123, 443, 1194, 65142})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,73 @@
package configuration
import (
"fmt"
"net"
"strings"
"github.com/qdm12/golibs/params"
)
// Wireguard contains settings to configure the Wireguard client.
type Wireguard struct {
PrivateKey string `json:"privatekey"`
PreSharedKey string `json:"presharedkey"`
Address *net.IPNet `json:"address"`
Interface string `json:"interface"`
}
func (settings *Wireguard) String() string {
return strings.Join(settings.lines(), "\n")
}
func (settings *Wireguard) lines() (lines []string) {
lines = append(lines, lastIndent+"Wireguard:")
lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface)
if settings.PrivateKey != "" {
lines = append(lines, indent+lastIndent+"Private key is set")
}
if settings.PreSharedKey != "" {
lines = append(lines, indent+lastIndent+"Pre-shared key is set")
}
if settings.Address != nil {
lines = append(lines, indent+lastIndent+"Address: "+settings.Address.String())
}
return lines
}
func (settings *Wireguard) read(r reader) (err error) {
settings.PrivateKey, err = r.env.Get("WIREGUARD_PRIVATE_KEY",
params.CaseSensitiveValue(), params.Unset(), params.Compulsory())
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_PRIVATE_KEY: %w", err)
}
settings.PreSharedKey, err = r.env.Get("WIREGUARD_PRESHARED_KEY",
params.CaseSensitiveValue(), params.Unset())
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_PRESHARED_KEY: %w", err)
}
addressString, err := r.env.Get("WIREGUARD_ADDRESS", params.Compulsory())
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_ADDRESS: %w", err)
}
ip, ipNet, err := net.ParseCIDR(addressString)
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_ADDRESS: %w", err)
}
ipNet.IP = ip
settings.Address = ipNet
settings.Interface, err = r.env.Get("WIREGUARD_INTERFACE", params.Default("wg0"))
if err != nil {
return fmt.Errorf("environment variable WIREGUARD_INTERFACE: %w", err)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -73,7 +73,7 @@ func Test_versions(t *testing.T) {
"Mullvad": {
model: models.MullvadServer{},
version: allServers.Mullvad.Version,
digest: "2a009192",
digest: "ec56f19d",
},
"Nordvpn": {
model: models.NordvpnServer{},
@@ -128,7 +128,7 @@ func Test_versions(t *testing.T) {
"Windscribe": {
model: models.WindscribeServer{},
version: allServers.Windscribe.Version,
digest: "6f6c16d6",
digest: "4bd0fc4f",
},
}
for name, testCase := range testCases {

View File

@@ -1,7 +1,8 @@
package constants
const (
OpenVPN = "openvpn"
OpenVPN = "openvpn"
Wireguard = "wireguard"
)
const (

View File

@@ -6,7 +6,7 @@ import (
)
type Connection struct {
// Type is the connection type and can be "openvpn"
// Type is the connection type and can be "openvpn" or "wireguard"
Type string `json:"type"`
// IP is the VPN server IP address.
IP net.IP `json:"ip"`
@@ -15,13 +15,17 @@ type Connection struct {
// Protocol can be "tcp" or "udp".
Protocol string `json:"protocol"`
// Hostname is used for IPVanish, IVPN, Privado
// and Windscribe for TLS verification
// and Windscribe for TLS verification.
Hostname string `json:"hostname"`
// PubKey is the public key of the VPN server,
// used only for Wireguard.
PubKey string `json:"pubkey"`
}
func (c *Connection) Equal(other Connection) bool {
return c.IP.Equal(other.IP) && c.Port == other.Port &&
c.Protocol == other.Protocol && c.Hostname == other.Hostname
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
c.PubKey == other.PubKey
}
func (c Connection) OpenVPNRemoteLine() (line string) {

View File

@@ -48,6 +48,7 @@ type IvpnServer struct {
}
type MullvadServer struct {
VPN string `json:"vpn"`
IPs []net.IP `json:"ips"`
IPsV6 []net.IP `json:"ipsv6"`
Country string `json:"country"`
@@ -55,6 +56,7 @@ type MullvadServer struct {
Hostname string `json:"hostname"`
ISP string `json:"isp"`
Owned bool `json:"owned"`
WgPubKey string `json:"wgpubkey,omitempty"`
}
type NordvpnServer struct { //nolint:maligned
@@ -149,9 +151,11 @@ type VyprvpnServer struct {
}
type WindscribeServer struct {
VPN string `json:"vpn"`
Region string `json:"region"`
City string `json:"city"`
Hostname string `json:"hostname"`
OvpnX509 string `json:"x509"`
OvpnX509 string `json:"x509,omitempty"`
WgPubKey string `json:"wgpubkey,omitempty"`
IPs []net.IP `json:"ips"`
}

View File

@@ -0,0 +1,7 @@
package netlink
import "github.com/vishvananda/netlink"
func (n *NetLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error {
return netlink.AddrAdd(link, addr)
}

View File

@@ -0,0 +1,14 @@
package netlink
import "github.com/vishvananda/netlink"
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . NetLinker
var _ NetLinker = (*NetLink)(nil)
type NetLinker interface {
AddrAdd(link netlink.Link, addr *netlink.Addr) error
RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
}

View File

@@ -0,0 +1,7 @@
package netlink
type NetLink struct{}
func New() *NetLink {
return &NetLink{}
}

View File

@@ -0,0 +1,7 @@
package netlink
import "github.com/vishvananda/netlink"
func (n *NetLink) RouteAdd(route *netlink.Route) error {
return netlink.RouteAdd(route)
}

11
internal/netlink/rule.go Normal file
View File

@@ -0,0 +1,11 @@
package netlink
import "github.com/vishvananda/netlink"
func (n *NetLink) RuleAdd(rule *netlink.Rule) error {
return netlink.RuleAdd(rule)
}
func (n *NetLink) RuleDel(rule *netlink.Rule) error {
return netlink.RuleDel(rule)
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -19,7 +20,8 @@ func Test_Cyberghost_filterServers(t *testing.T) {
err error
}{
"no servers": {
err: errors.New("no server found: for protocol udp"),
selection: configuration.ServerSelection{VPN: constants.OpenVPN},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"servers without filter defaults to UDP": {
servers: []models.CyberghostServer{

View File

@@ -9,16 +9,8 @@ import (
func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) {
var port uint16 = 1194
protocol := constants.UDP
if selection.OpenVPN.TCP {
port = 443
protocol = constants.TCP
}
if selection.OpenVPN.CustomPort > 0 {
port = selection.OpenVPN.CustomPort
}
port := getPort(selection)
protocol := getProtocol(selection)
servers, err := m.filterServers(selection)
if err != nil {
@@ -33,6 +25,7 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
IP: IP,
Port: port,
Protocol: protocol,
PubKey: server.WgPubKey, // Wireguard only
}
connections = append(connections, connection)
}
@@ -44,3 +37,33 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
return utils.PickRandomConnection(connections, m.randSource), nil
}
func getPort(selection configuration.ServerSelection) (port uint16) {
switch selection.VPN {
case constants.Wireguard:
customPort := selection.Wireguard.CustomPort
if customPort > 0 {
return customPort
}
const defaultPort = 51820
return defaultPort
default: // OpenVPN
customPort := selection.OpenVPN.CustomPort
if customPort > 0 {
return customPort
}
port = 1194
if selection.OpenVPN.TCP {
port = 443
}
return port
}
}
func getProtocol(selection configuration.ServerSelection) (protocol string) {
protocol = constants.UDP
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
protocol = constants.TCP
}
return protocol
}

View File

@@ -0,0 +1,204 @@
package mullvad
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Mullvad_GetConnection(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.MullvadServer
selection configuration.ServerSelection
connection models.Connection
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.MullvadServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(1, 1, 1, 1),
Port: 1194,
Protocol: constants.UDP,
},
},
"target IP": {
selection: configuration.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
},
servers: []models.MullvadServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.MullvadServer{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
connection, err := m.GetConnection(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.connection, connection)
})
}
}
func Test_getPort(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
port uint16
}{
"default": {
port: 1194,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
port: 1194,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
port: 443,
},
"OpenVPN custom port": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
CustomPort: 1234,
},
},
port: 1234,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
port: 51820,
},
"Wireguard custom port": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
Wireguard: configuration.WireguardSelection{
CustomPort: 1234,
},
},
port: 1234,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
port := getPort(testCase.selection)
assert.Equal(t, testCase.port, port)
})
}
}
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -11,6 +11,7 @@ func (m *Mullvad) filterServers(selection configuration.ServerSelection) (
for _, server := range m.servers {
switch {
case
server.VPN != selection.VPN,
utils.FilterByPossibilities(server.Country, selection.Countries),
utils.FilterByPossibilities(server.City, selection.Cities),
utils.FilterByPossibilities(server.ISP, selection.ISPs),

View File

@@ -0,0 +1,143 @@
package mullvad
import (
"errors"
"math/rand"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Mullvad_filterServers(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.MullvadServer
selection configuration.ServerSelection
filtered []models.MullvadServer
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
},
"filter OpenVPN out": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
servers: []models.MullvadServer{
{VPN: constants.OpenVPN, Hostname: "a"},
{VPN: constants.Wireguard, Hostname: "b"},
{VPN: constants.OpenVPN, Hostname: "c"},
},
filtered: []models.MullvadServer{
{VPN: constants.Wireguard, Hostname: "b"},
},
},
"filter by country": {
selection: configuration.ServerSelection{
Countries: []string{"b"},
},
servers: []models.MullvadServer{
{Country: "a"},
{Country: "b"},
{Country: "c"},
},
filtered: []models.MullvadServer{
{Country: "b"},
},
},
"filter by city": {
selection: configuration.ServerSelection{
Cities: []string{"b"},
},
servers: []models.MullvadServer{
{City: "a"},
{City: "b"},
{City: "c"},
},
filtered: []models.MullvadServer{
{City: "b"},
},
},
"filter by ISP": {
selection: configuration.ServerSelection{
ISPs: []string{"b"},
},
servers: []models.MullvadServer{
{ISP: "a"},
{ISP: "b"},
{ISP: "c"},
},
filtered: []models.MullvadServer{
{ISP: "b"},
},
},
"filter by hostname": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.MullvadServer{
{Hostname: "b"},
},
},
"filter by owned": {
selection: configuration.ServerSelection{
Owned: true,
},
servers: []models.MullvadServer{
{Hostname: "a"},
{Hostname: "b", Owned: true},
{Hostname: "c"},
},
filtered: []models.MullvadServer{
{Hostname: "b", Owned: true},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
servers, err := m.filterServers(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.filtered, servers)
})
}
}

View File

@@ -19,6 +19,8 @@ var ErrNoServerFound = errors.New("no server found")
func NoServerFoundError(selection configuration.ServerSelection) (err error) {
var messageParts []string
messageParts = append(messageParts, "VPN "+selection.VPN)
protocol := constants.UDP
if selection.OpenVPN.TCP {
protocol = constants.TCP

View File

@@ -0,0 +1,34 @@
package utils
import (
"net"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/wireguard"
)
func BuildWireguardSettings(connection models.Connection,
userSettings configuration.Wireguard) (settings wireguard.Settings) {
settings.PrivateKey = userSettings.PrivateKey
settings.PublicKey = connection.PubKey
settings.PreSharedKey = userSettings.PreSharedKey
settings.InterfaceName = userSettings.Interface
const routePriority = 101 // 100 is to receive external connections
settings.RulePriority = routePriority
settings.Endpoint = new(net.UDPAddr)
settings.Endpoint.IP = make(net.IP, len(connection.IP))
copy(settings.Endpoint.IP, connection.IP)
settings.Endpoint.Port = int(connection.Port)
address := new(net.IPNet)
address.IP = make(net.IP, len(userSettings.Address.IP))
copy(address.IP, userSettings.Address.IP)
address.Mask = make(net.IPMask, len(userSettings.Address.Mask))
copy(address.Mask, userSettings.Address.Mask)
settings.Addresses = append(settings.Addresses, address)
return settings
}

View File

@@ -9,16 +9,8 @@ import (
func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) {
protocol := constants.UDP
var port uint16 = 443
if selection.OpenVPN.TCP {
protocol = constants.TCP
port = 1194
}
if selection.OpenVPN.CustomPort > 0 {
port = selection.OpenVPN.CustomPort
}
port := getPort(selection)
protocol := getProtocol(selection)
servers, err := w.filterServers(selection)
if err != nil {
@@ -34,6 +26,7 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
Port: port,
Protocol: protocol,
Hostname: server.OvpnX509,
PubKey: server.WgPubKey,
}
connections = append(connections, connection)
}
@@ -45,3 +38,33 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
return utils.PickRandomConnection(connections, w.randSource), nil
}
func getPort(selection configuration.ServerSelection) (port uint16) {
switch selection.VPN {
case constants.Wireguard:
customPort := selection.Wireguard.CustomPort
if customPort > 0 {
return customPort
}
const defaultPort = 1194
return defaultPort
default: // OpenVPN
customPort := selection.OpenVPN.CustomPort
if customPort > 0 {
return customPort
}
port = 1194
if selection.OpenVPN.TCP {
port = 443
}
return port
}
}
func getProtocol(selection configuration.ServerSelection) (protocol string) {
protocol = constants.UDP
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
protocol = constants.TCP
}
return protocol
}

View File

@@ -0,0 +1,204 @@
package windscribe
import (
"errors"
"math/rand"
"net"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Windscribe_GetConnection(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.WindscribeServer
selection configuration.ServerSelection
connection models.Connection
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.WindscribeServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(1, 1, 1, 1),
Port: 1194,
Protocol: constants.UDP,
},
},
"target IP": {
selection: configuration.ServerSelection{
TargetIP: net.IPv4(2, 2, 2, 2),
},
servers: []models.WindscribeServer{
{IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
"with filter": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.WindscribeServer{
{Hostname: "a", IPs: []net.IP{net.IPv4(1, 1, 1, 1)}},
{Hostname: "b", IPs: []net.IP{net.IPv4(2, 2, 2, 2)}},
{Hostname: "a", IPs: []net.IP{net.IPv4(3, 3, 3, 3)}},
},
connection: models.Connection{
IP: net.IPv4(2, 2, 2, 2),
Port: 1194,
Protocol: constants.UDP,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
connection, err := m.GetConnection(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.connection, connection)
})
}
}
func Test_getPort(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
port uint16
}{
"default": {
port: 1194,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
port: 1194,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
port: 443,
},
"OpenVPN custom port": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
CustomPort: 1234,
},
},
port: 1234,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
port: 1194,
},
"Wireguard custom port": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
Wireguard: configuration.WireguardSelection{
CustomPort: 1234,
},
},
port: 1234,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
port := getPort(testCase.selection)
assert.Equal(t, testCase.port, port)
})
}
}
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -11,6 +11,7 @@ func (w *Windscribe) filterServers(selection configuration.ServerSelection) (
for _, server := range w.servers {
switch {
case
server.VPN != selection.VPN,
utils.FilterByPossibilities(server.Region, selection.Regions),
utils.FilterByPossibilities(server.City, selection.Cities),
utils.FilterByPossibilities(server.Hostname, selection.Hostnames):

View File

@@ -0,0 +1,117 @@
package windscribe
import (
"errors"
"math/rand"
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Windscribe_filterServers(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
servers []models.WindscribeServer
selection configuration.ServerSelection
filtered []models.WindscribeServer
err error
}{
"no server available": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
err: errors.New("no server found: for VPN openvpn; protocol udp"),
},
"no filter": {
servers: []models.WindscribeServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.WindscribeServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
},
"filter OpenVPN out": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
servers: []models.WindscribeServer{
{VPN: constants.OpenVPN, Hostname: "a"},
{VPN: constants.Wireguard, Hostname: "b"},
{VPN: constants.OpenVPN, Hostname: "c"},
},
filtered: []models.WindscribeServer{
{VPN: constants.Wireguard, Hostname: "b"},
},
},
"filter by region": {
selection: configuration.ServerSelection{
Regions: []string{"b"},
},
servers: []models.WindscribeServer{
{Region: "a"},
{Region: "b"},
{Region: "c"},
},
filtered: []models.WindscribeServer{
{Region: "b"},
},
},
"filter by city": {
selection: configuration.ServerSelection{
Cities: []string{"b"},
},
servers: []models.WindscribeServer{
{City: "a"},
{City: "b"},
{City: "c"},
},
filtered: []models.WindscribeServer{
{City: "b"},
},
},
"filter by hostname": {
selection: configuration.ServerSelection{
Hostnames: []string{"b"},
},
servers: []models.WindscribeServer{
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
filtered: []models.WindscribeServer{
{Hostname: "b"},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
randSource := rand.NewSource(0)
m := New(testCase.servers, randSource)
servers, err := m.filterServers(testCase.selection)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.filtered, servers)
})
}
}

View File

@@ -22,10 +22,12 @@ type serverData struct {
Provider string `json:"provider"`
IPv4 string `json:"ipv4_addr_in"`
IPv6 string `json:"ipv6_addr_in"`
Type string `json:"type"`
PubKey string `json:"pubkey"` // Wireguard public key
}
func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err error) {
const url = "https://api.mullvad.net/www/relays/openvpn/"
const url = "https://api.mullvad.net/www/relays/all/"
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {

View File

@@ -6,14 +6,17 @@ import (
"net"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
type hostToServer map[string]models.MullvadServer
var (
ErrParseIPv4 = errors.New("cannot parse IPv4 address")
ErrParseIPv6 = errors.New("cannot parse IPv6 address")
ErrNoIP = errors.New("no IP address for VPN server")
ErrParseIPv4 = errors.New("cannot parse IPv4 address")
ErrParseIPv6 = errors.New("cannot parse IPv6 address")
ErrVPNTypeNotSupported = errors.New("VPN type not supported")
)
func (hts hostToServer) add(data serverData) (err error) {
@@ -21,14 +24,8 @@ func (hts hostToServer) add(data serverData) (err error) {
return
}
ipv4 := net.ParseIP(data.IPv4)
if ipv4 == nil || ipv4.To4() == nil {
return fmt.Errorf("%w: %s", ErrParseIPv4, data.IPv4)
}
ipv6 := net.ParseIP(data.IPv6)
if ipv6 == nil || ipv6.To4() != nil {
return fmt.Errorf("%w: %s", ErrParseIPv6, data.IPv6)
if data.IPv4 == "" && data.IPv6 == "" {
return ErrNoIP
}
server, ok := hts[data.Hostname]
@@ -36,13 +33,40 @@ func (hts hostToServer) add(data serverData) (err error) {
return nil
}
switch data.Type {
case "openvpn":
server.VPN = constants.OpenVPN
case "wireguard":
server.VPN = constants.Wireguard
case "bridge":
// ignore bridge servers
return nil
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotSupported, data.Type)
}
if data.IPv4 != "" {
ipv4 := net.ParseIP(data.IPv4)
if ipv4 == nil || ipv4.To4() == nil {
return fmt.Errorf("%w: %s", ErrParseIPv4, data.IPv4)
}
server.IPs = []net.IP{ipv4}
}
if data.IPv6 != "" {
ipv6 := net.ParseIP(data.IPv6)
if ipv6 == nil || ipv6.To4() != nil {
return fmt.Errorf("%w: %s", ErrParseIPv6, data.IPv6)
}
server.IPsV6 = []net.IP{ipv6}
}
server.Country = data.Country
server.City = strings.ReplaceAll(data.City, ",", "")
server.Hostname = data.Hostname
server.ISP = data.Provider
server.Owned = data.Owned
server.IPs = []net.IP{ipv4}
server.IPsV6 = []net.IP{ipv6}
server.WgPubKey = data.PubKey
hts[data.Hostname] = server

View File

@@ -29,6 +29,7 @@ type groupData struct {
City string `json:"city"`
Nodes []serverData `json:"nodes"`
OvpnX509 string `json:"ovpn_x509"`
WgPubKey string `json:"wg_pubkey"`
}
type serverData struct {

View File

@@ -9,10 +9,14 @@ import (
"net"
"net/http"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
var ErrNotEnoughServers = errors.New("not enough servers found")
var (
ErrNotEnoughServers = errors.New("not enough servers found")
ErrNoWireguardKey = errors.New("no wireguard public key found")
)
func GetServers(ctx context.Context, client *http.Client, minServers int) (
servers []models.WindscribeServer, err error) {
@@ -26,19 +30,17 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
for _, group := range regionData.Groups {
city := group.City
x5090Name := group.OvpnX509
wgPubKey := group.WgPubKey
for _, node := range group.Nodes {
const maxIPsPerNode = 3
ips := make([]net.IP, 0, maxIPsPerNode)
ips := make([]net.IP, 0, 2) // nolint:gomnd
if node.IP != nil {
ips = append(ips, node.IP)
}
if node.IP2 != nil {
ips = append(ips, node.IP2)
}
// if node.IP3 != nil { // Wireguard + Stealth
// ips = append(ips, node.IP3)
// }
server := models.WindscribeServer{
VPN: constants.OpenVPN,
Region: region,
City: city,
Hostname: node.Hostname,
@@ -46,6 +48,18 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
IPs: ips,
}
servers = append(servers, server)
if node.IP3 == nil { // Wireguard + Stealth
continue
} else if wgPubKey == "" {
return nil, fmt.Errorf("%w: for node %s", ErrNoWireguardKey, node.Hostname)
}
server.VPN = constants.Wireguard
server.OvpnX509 = ""
server.WgPubKey = wgPubKey
server.IPs = []net.IP{node.IP3}
servers = append(servers, server)
}
}
}

View File

@@ -10,6 +10,9 @@ func sortServers(servers []models.WindscribeServer) {
sort.Slice(servers, func(i, j int) bool {
if servers[i].Region == servers[j].Region {
if servers[i].City == servers[j].City {
if servers[i].Hostname == servers[j].Hostname {
return servers[i].VPN < servers[j].VPN
}
return servers[i].Hostname < servers[j].Hostname
}
return servers[i].City < servers[j].City

View File

@@ -10,6 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip"
@@ -37,6 +38,7 @@ type Loop struct {
versionInfo bool
// Configurators
openvpnConf openvpn.Interface
netLinker netlink.NetLinker
fw firewallConfigurer
routing routing.VPNGetter
portForward portforward.StartStopper
@@ -67,7 +69,7 @@ const (
func NewLoop(vpnSettings configuration.VPN,
allServers models.AllServers, openvpnConf openvpn.Interface,
fw firewallConfigurer, routing routing.VPNGetter,
netLinker netlink.NetLinker, fw firewallConfigurer, routing routing.VPNGetter,
portForward portforward.StartStopper, starter command.Starter,
publicip publicip.Looper, dnsLooper dns.Looper,
logger logging.Logger, client *http.Client,
@@ -86,6 +88,7 @@ func NewLoop(vpnSettings configuration.VPN,
buildInfo: buildInfo,
versionInfo: versionInfo,
openvpnConf: openvpnConf,
netLinker: netLinker,
fw: fw,
routing: routing,
portForward: portForward,

View File

@@ -30,8 +30,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.starter, l.logger)
var vpnRunner vpnRunner
var serverName, vpnInterface string
var err error
if settings.Type == constants.OpenVPN {
vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.starter, l.logger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.logger)
}
if err != nil {
l.crashed(ctx, err)
continue
@@ -40,7 +49,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
portForwarding: settings.Provider.PortForwarding.Enabled,
serverName: serverName,
portForwarder: providerConf,
vpnIntf: settings.OpenVPN.Interface,
vpnIntf: vpnInterface,
}
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())

View File

@@ -17,13 +17,6 @@ type tunnelUpData struct {
}
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
vpnDestination, err := l.routing.VPNDestinationIP()
if err != nil {
l.logger.Warn(err.Error())
} else {
l.logger.Info("VPN routing IP address: " + vpnDestination.String())
}
if l.dnsLooper.GetSettings().Enabled {
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
}
@@ -40,7 +33,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
}
}
err = l.startPortForwarding(ctx, data)
err := l.startPortForwarding(ctx, data)
if err != nil {
l.logger.Error(err.Error())
}

45
internal/vpn/wireguard.go Normal file
View File

@@ -0,0 +1,45 @@
package vpn
import (
"context"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/wireguard"
)
var (
errGetServer = errors.New("failed finding a VPN server")
errCreateWireguard = errors.New("failed creating Wireguard")
)
// setupWireguard sets Wireguard up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
fw firewall.VPNConnectionSetter, providerConf provider.Provider,
settings configuration.VPN, logger wireguard.Logger) (
wireguarder wireguard.Wireguarder, serverName string, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errGetServer, err)
}
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard)
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errCreateWireguard, err)
}
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
}
return wireguarder, connection.Hostname, nil
}

View File

@@ -0,0 +1,25 @@
package wireguard
import (
"fmt"
"net"
"github.com/vishvananda/netlink"
)
func (w *Wireguard) addAddresses(link netlink.Link,
addresses []*net.IPNet) (err error) {
for _, ipNet := range addresses {
address := &netlink.Addr{
IPNet: ipNet,
}
err = w.netlink.AddrAdd(link, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Attrs().Name)
}
}
return nil
}

View File

@@ -0,0 +1,94 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)}
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
addrs []*net.IPNet
expectedAddrs []*netlink.Addr
addrAddErrs []error
err error
}{
"success": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
},
addrAddErrs: []error{nil, nil},
},
"first add error": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{
{IPNet: ipNetOne},
},
addrAddErrs: []error{errDummy},
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
},
"second add error": {
link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
},
addrAddErrs: []error{nil, errDummy},
err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs))
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
for i := range testCase.expectedAddrs {
netLinker.EXPECT().
AddrAdd(testCase.link, testCase.expectedAddrs[i]).
Return(testCase.addrAddErrs[i])
}
err := wg.addAddresses(testCase.link, testCase.addrs)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,59 @@
package wireguard
import "sort"
type closer struct {
operation string
step step
close func() error
closed bool
}
type closers []closer
func (c *closers) add(operation string, step step,
closeFunc func() error) {
closer := closer{
operation: operation,
step: step,
close: closeFunc,
}
*c = append(*c, closer)
}
func (c *closers) cleanup(logger Logger) {
closers := *c
sort.Slice(closers, func(i, j int) bool {
return closers[i].step < closers[j].step
})
for i, closer := range closers {
if closer.closed {
continue
} else {
closers[i].closed = true
}
logger.Debug(closer.operation + "...")
err := closer.close()
if err != nil {
logger.Error("failed " + closer.operation + ": " + err.Error())
}
}
}
type step int
const (
// stepOne closes the wireguard controller client,
// and removes the IP rule.
stepOne step = iota
// stepTwo closes the UAPI listener.
stepTwo
// stepThree closes the UAPI file.
stepThree
// stepFour closes the Wireguard device.
stepFour
// stepFive closes the bind connection and the TUN device file.
stepFive
)

View File

@@ -0,0 +1,57 @@
package wireguard
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_closers(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var ACloseCalled, BCloseCalled, CCloseCalled bool
var (
AErr error
BErr = errors.New("B failed")
CErr = errors.New("C failed")
)
var closers closers
closers.add("closing A", stepFive, func() error {
ACloseCalled = true
return AErr
})
closers.add("closing B", stepThree, func() error {
BCloseCalled = true
return BErr
})
closers.add("closing C", stepTwo, func() error {
CCloseCalled = true
return CErr
})
logger := NewMockLogger(ctrl)
prevCall := logger.EXPECT().Debug("closing C...")
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
logger.EXPECT().Debug("closing A...").After(prevCall)
closers.cleanup(logger)
closers.cleanup(logger) // run twice should not close already closed
for _, closer := range closers {
assert.True(t, closer.closed)
}
assert.True(t, ACloseCalled)
assert.True(t, BCloseCalled)
assert.True(t, CCloseCalled)
}

View File

@@ -0,0 +1,86 @@
package wireguard
import (
"errors"
"fmt"
"net"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
errMakeConfig = errors.New("cannot make device configuration")
errConfigureDevice = errors.New("cannot configure device")
)
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
deviceConfig, err := makeDeviceConfig(settings)
if err != nil {
return fmt.Errorf("%w: %s", errMakeConfig, err)
}
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
if err != nil {
return fmt.Errorf("%w: %s", errConfigureDevice, err)
}
return nil
}
func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
privateKey, err := wgtypes.ParseKey(settings.PrivateKey)
if err != nil {
return config, ErrPrivateKeyInvalid
}
publicKey, err := wgtypes.ParseKey(settings.PublicKey)
if err != nil {
return config, fmt.Errorf("%w: %s", ErrPublicKeyInvalid, settings.PublicKey)
}
var preSharedKey *wgtypes.Key
if settings.PreSharedKey != "" {
preSharedKeyValue, err := wgtypes.ParseKey(settings.PreSharedKey)
if err != nil {
return config, ErrPreSharedKeyInvalid
}
preSharedKey = &preSharedKeyValue
}
firewallMark := settings.FirewallMark
config = wgtypes.Config{
PrivateKey: &privateKey,
ReplacePeers: true,
FirewallMark: &firewallMark,
Peers: []wgtypes.PeerConfig{
{
PublicKey: publicKey,
PresharedKey: preSharedKey,
AllowedIPs: []net.IPNet{
*allIPv4(),
*allIPv6(),
},
ReplaceAllowedIPs: true,
Endpoint: settings.Endpoint,
},
},
}
return config, nil
}
func allIPv4() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
}
}
func allIPv6() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
}
}

View File

@@ -0,0 +1,126 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func Test_makeDeviceConfig(t *testing.T) {
t.Parallel()
const (
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
validKey3 = "gFIW0lTmBYEucynoIg+XmeWckDUXTcC4Po5ijR5G+HM="
)
parseKey := func(t *testing.T, s string) *wgtypes.Key {
t.Helper()
key, err := wgtypes.ParseKey(s)
require.NoError(t, err)
return &key
}
intPtr := func(n int) *int { return &n }
testCases := map[string]struct {
settings Settings
config wgtypes.Config
err error
}{
"bad private key": {
settings: Settings{
PrivateKey: "bad key",
},
err: ErrPrivateKeyInvalid,
},
"bad public key": {
settings: Settings{
PrivateKey: validKey1,
PublicKey: "bad key",
},
err: errors.New("cannot parse public key: bad key"),
},
"bad pre-shared key": {
settings: Settings{
PrivateKey: validKey1,
PublicKey: validKey2,
PreSharedKey: "bad key",
},
err: errors.New("cannot parse pre-shared key"),
},
"valid settings": {
settings: Settings{
PrivateKey: validKey1,
PublicKey: validKey2,
PreSharedKey: validKey3,
FirewallMark: 9876,
Endpoint: &net.UDPAddr{
IP: net.IPv4(99, 99, 99, 99),
Port: 51820,
},
},
config: wgtypes.Config{
PrivateKey: parseKey(t, validKey1),
ReplacePeers: true,
FirewallMark: intPtr(9876),
Peers: []wgtypes.PeerConfig{
{
PublicKey: *parseKey(t, validKey2),
PresharedKey: parseKey(t, validKey3),
AllowedIPs: []net.IPNet{
{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
},
{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
},
},
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
IP: net.IPv4(99, 99, 99, 99),
Port: 51820,
},
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
config, err := makeDeviceConfig(testCase.settings)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.config, config)
})
}
}
func Test_allIPv4(t *testing.T) {
t.Parallel()
ipNet := allIPv4()
assert.Equal(t, "0.0.0.0/0", ipNet.String())
}
func Test_allIPv6(t *testing.T) {
t.Parallel()
ipNet := allIPv6()
assert.Equal(t, "::/0", ipNet.String())
}

View File

@@ -0,0 +1,30 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
var _ Wireguarder = (*Wireguard)(nil)
type Wireguarder interface {
Runner
Runner
}
type Wireguard struct {
logger Logger
settings Settings
netlink netlink.NetLinker
}
func New(settings Settings, netlink NetLinker,
logger Logger) (w *Wireguard, err error) {
settings.SetDefaults()
if err := settings.Check(); err != nil {
return nil, err
}
return &Wireguard{
logger: logger,
settings: settings,
netlink: netlink,
}, nil
}

View File

@@ -0,0 +1,80 @@
package wireguard
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_New(t *testing.T) {
t.Parallel()
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
logger := NewMockLogger(nil)
netLinker := NewMockNetLinker(nil)
testCases := map[string]struct {
settings Settings
wireguard *Wireguard
err error
}{
"bad settings": {
settings: Settings{
PrivateKey: "",
},
err: ErrPrivateKeyMissing,
},
"minimal valid settings": {
settings: Settings{
PrivateKey: validKeyString,
PublicKey: validKeyString,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
Addresses: []*net.IPNet{{
IP: net.IPv4(5, 6, 7, 8),
Mask: net.IPv4Mask(255, 255, 255, 255)},
},
FirewallMark: 100,
},
wireguard: &Wireguard{
logger: logger,
netlink: netLinker,
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKeyString,
PublicKey: validKeyString,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{
IP: net.IPv4(5, 6, 7, 8),
Mask: net.IPv4Mask(255, 255, 255, 255)},
},
FirewallMark: 100,
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
wireguard, err := New(testCase.settings, netLinker, logger)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.wireguard, wireguard)
})
}
}

26
internal/wireguard/log.go Normal file
View File

@@ -0,0 +1,26 @@
package wireguard
import (
"fmt"
"golang.zx2c4.com/wireguard/device"
)
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
return &device.Logger{
Verbosef: func(format string, args ...interface{}) {
logger.Debug(fmt.Sprintf(format, args...))
},
Errorf: func(format string, args ...interface{}) {
logger.Error(fmt.Sprintf(format, args...))
},
}
}

View File

@@ -0,0 +1,70 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: Logger)
// Package wireguard is a generated GoMock package.
package wireguard
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}

View File

@@ -0,0 +1,23 @@
package wireguard
import (
"testing"
"github.com/golang/mock/gomock"
)
func Test_makeDeviceLogger(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
logger := NewMockLogger(ctrl)
deviceLogger := makeDeviceLogger(logger)
logger.EXPECT().Debug("test 1")
deviceLogger.Verbosef("test %d", 1)
logger.EXPECT().Error("test 2")
deviceLogger.Errorf("test %d", 2)
}

View File

@@ -0,0 +1,113 @@
// +build netlink
package wireguard
import (
"fmt"
"math/rand"
"net"
"testing"
inetlink "github.com/qdm12/gluetun/internal/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
t.Parallel()
netlinker := inetlink.New()
wg := &Wireguard{
netlink: netlinker,
}
intfName := "test_" + fmt.Sprint(rand.Intn(10000)) //nolint:gosec
// Add link
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = intfName
link := &netlink.Bridge{
LinkAttrs: linkAttrs,
}
err := netlink.LinkAdd(link)
require.NoError(t, err)
defer func() {
err = netlink.LinkDel(link)
assert.NoError(t, err)
}()
addresses := []*net.IPNet{
{IP: net.IP{1, 2, 3, 4}, Mask: net.IPv4Mask(255, 255, 255, 255)},
{IP: net.IP{5, 6, 7, 8}, Mask: net.IPv4Mask(255, 255, 255, 255)},
}
// Success
err = wg.addAddresses(link, addresses)
require.NoError(t, err)
netlinkAddresses, err := netlink.AddrList(link, netlink.FAMILY_ALL)
require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
ipNet := netlinkAddress.IPNet
assert.Equal(t, addresses[i], ipNet)
}
// Existing address cannot be added
err = wg.addAddresses(link, addresses)
require.Error(t, err)
assert.Equal(t, "file exists: when adding address 1.2.3.4/32 to link test_8081", err.Error())
}
func Test_netlink_Wireguard_addRule(t *testing.T) {
t.Parallel()
netlinker := inetlink.New()
wg := &Wireguard{
netlink: netlinker,
}
rulePriority := 10000
const firewallMark = 999
cleanup, err := wg.addRule(rulePriority, firewallMark)
require.NoError(t, err)
defer func() {
err := cleanup()
assert.NoError(t, err)
}()
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
require.NoError(t, err)
var rule netlink.Rule
var ruleFound bool
for _, rule = range rules {
if rule.Mark == firewallMark {
ruleFound = true
break
}
}
require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: 4294967295,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
}
assert.Equal(t, expectedRule, rule)
// Existing rule cannot be added
nilCleanup, err := wg.addRule(rulePriority, firewallMark)
if nilCleanup != nil {
_ = nilCleanup() // in case it succeeds
}
require.Error(t, err)
assert.Equal(t, "file exists: when adding rule: ip rule 10000: from <nil> table 999", err.Error())
}

View File

@@ -0,0 +1,12 @@
package wireguard
import "github.com/vishvananda/netlink"
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrAdd(link netlink.Link, addr *netlink.Addr) error
RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
}

View File

@@ -0,0 +1,91 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: NetLinker)
// Package wireguard is a generated GoMock package.
package wireguard
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
type MockNetLinker struct {
ctrl *gomock.Controller
recorder *MockNetLinkerMockRecorder
}
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
type MockNetLinkerMockRecorder struct {
mock *MockNetLinker
}
// NewMockNetLinker creates a new mock instance.
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
mock := &MockNetLinker{ctrl: ctrl}
mock.recorder = &MockNetLinkerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
return m.recorder
}
// AddrAdd mocks base method.
func (m *MockNetLinker) AddrAdd(arg0 netlink.Link, arg1 *netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrAdd", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// AddrAdd indicates an expected call of AddrAdd.
func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1)
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RouteAdd indicates an expected call of RouteAdd.
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RuleAdd indicates an expected call of RuleAdd.
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RuleDel indicates an expected call of RuleDel.
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
}

View File

@@ -0,0 +1,26 @@
package wireguard
import (
"fmt"
"net"
"github.com/vishvananda/netlink"
)
// TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
firewallMark int) (err error) {
route := &netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: dst,
Table: firewallMark,
}
err = w.netlink.RouteAdd(route)
if err != nil {
return fmt.Errorf("%w: when adding route: %s", err, route)
}
return err
}

View File

@@ -0,0 +1,85 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_Wireguard_addRoute(t *testing.T) {
t.Parallel()
const linkIndex = 88
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
linkAttrs.Index = linkIndex
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
const firewallMark = 51820
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst *net.IPNet
expectedRoute *netlink.Route
routeAddErr error
err error
}{
"success": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Table: firewallMark,
},
},
"route add error": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Table: firewallMark,
},
routeAddErr: errDummy,
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820}"), //nolint:lll
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().
RouteAdd(testCase.expectedRoute).
Return(testCase.routeAddErr)
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,28 @@
package wireguard
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *Wireguard) addRule(rulePriority, firewallMark int) (
cleanup func() error, err error) {
rule := netlink.NewRule()
rule.Invert = true
rule.Priority = rulePriority
rule.Mark = firewallMark
rule.Table = firewallMark
if err := w.netlink.RuleAdd(rule); err != nil {
return nil, fmt.Errorf("%w: when adding rule: %s", err, rule)
}
cleanup = func() error {
err := w.netlink.RuleDel(rule)
if err != nil {
return fmt.Errorf("%w: when deleting rule: %s", err, rule)
}
return nil
}
return cleanup, nil
}

View File

@@ -0,0 +1,106 @@
package wireguard
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
func Test_Wireguard_addRule(t *testing.T) {
t.Parallel()
const rulePriority = 987
const firewallMark = 456
errDummy := errors.New("dummy")
testCases := map[string]struct {
expectedRule *netlink.Rule
ruleAddErr error
err error
ruleDelErr error
cleanupErr error
}{
"success": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
},
},
"rule add error": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
},
ruleAddErr: errDummy,
err: errors.New("dummy: when adding rule: ip rule 987: from <nil> table 456"),
},
"rule delete error": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from <nil> table 456"),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
Return(testCase.ruleAddErr)
cleanup, err := wg.addRule(rulePriority, firewallMark)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
return
}
require.NoError(t, err)
netLinker.EXPECT().RuleDel(testCase.expectedRule).
Return(testCase.ruleDelErr)
err = cleanup()
if testCase.cleanupErr != nil {
require.Error(t, err)
assert.Equal(t, testCase.cleanupErr.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}

165
internal/wireguard/run.go Normal file
View File

@@ -0,0 +1,165 @@
package wireguard
import (
"context"
"errors"
"fmt"
"net"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl"
)
var (
ErrCreateTun = errors.New("cannot create TUN device")
ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device")
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
ErrWgctrlOpen = errors.New("cannot open wgctrl")
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
ErrAddAddress = errors.New("cannot add address to wireguard interface")
ErrConfigure = errors.New("cannot configure wireguard interface")
ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface")
ErrRuleAdd = errors.New("cannot add rule for interface")
ErrDeviceWaited = errors.New("device waited for")
)
type Runner interface {
Run(ctx context.Context, waitError chan<- error, ready chan<- struct{})
}
// See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
client, err := wgctrl.New()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
return
}
var closers closers
closers.add("closing controller client", stepOne, client.Close)
defer closers.cleanup(w.logger)
tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err)
return
}
closers.add("closing TUN device", stepFive, tun.Close)
tunName, err := tun.Name()
if err != nil {
waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return
} else if tunName != w.settings.InterfaceName {
waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, w.settings.InterfaceName, tunName)
return
}
link, err := netlink.LinkByName(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err)
return
}
bind := conn.NewDefaultBind()
closers.add("closing bind", stepFive, bind.Close)
deviceLogger := makeDeviceLogger(w.logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepFour, func() error {
device.Close()
return nil
})
uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err)
return
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
err = w.addAddresses(link, w.settings.Addresses)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
return
}
err = configureDevice(client, w.settings)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
return
}
if err := netlink.LinkSetUp(link); err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
}
ruleCleanup, err := w.addRule(
w.settings.RulePriority, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRuleAdd, err)
return
}
closers.add("removing rule", stepOne, ruleCleanup)
w.logger.Info("Wireguard is up")
ready <- struct{}{}
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = ErrDeviceWaited
}
closers.cleanup(w.logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
waitError <- err
}
func acceptAndHandle(uapi net.Listener, device *device.Device,
uapiAcceptErrorCh chan<- error) {
for { // stopped by uapiFile.Close()
conn, err := uapi.Accept()
if err != nil {
uapiAcceptErrorCh <- err
return
}
go device.IpcHandle(conn)
}
}

View File

@@ -0,0 +1,212 @@
package wireguard
import (
"errors"
"fmt"
"net"
"regexp"
"strings"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Settings struct {
// Interface name for the Wireguard interface.
// It defaults to wg0 if unset.
InterfaceName string
// Private key in base 64 format
PrivateKey string
// Public key in base 64 format
PublicKey string
// Pre shared key in base 64 format
PreSharedKey string
// Wireguard server endpoint to connect to.
Endpoint *net.UDPAddr
// Addresses assigned to the client.
Addresses []*net.IPNet
// FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0.
FirewallMark int
// RulePriority is the priority for the rule created with the
// FirewallMark.
RulePriority int
}
func (s *Settings) SetDefaults() {
if s.InterfaceName == "" {
const defaultInterfaceName = "wg0"
s.InterfaceName = defaultInterfaceName
}
if s.Endpoint != nil && s.Endpoint.Port == 0 {
const defaultPort = 51820
s.Endpoint.Port = defaultPort
}
if s.FirewallMark == 0 {
const defaultFirewallMark = 51820
s.FirewallMark = defaultFirewallMark
}
}
var (
ErrInterfaceNameInvalid = errors.New("invalid interface name")
ErrPrivateKeyMissing = errors.New("private key is missing")
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointMissing = errors.New("endpoint is missing")
ErrEndpointIPMissing = errors.New("endpoint IP is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNil = errors.New("interface address is nil")
ErrAddressIPMissing = errors.New("interface address IP is missing")
ErrAddressMaskMissing = errors.New("interface address mask is missing")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
)
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
func (s *Settings) Check() (err error) {
if !interfaceNameRegexp.MatchString(s.InterfaceName) {
return fmt.Errorf("%w: %s", ErrInterfaceNameInvalid, s.InterfaceName)
}
if s.PrivateKey == "" {
return ErrPrivateKeyMissing
} else if _, err := wgtypes.ParseKey(s.PrivateKey); err != nil {
return ErrPrivateKeyInvalid
}
if s.PublicKey == "" {
return ErrPublicKeyMissing
} else if _, err := wgtypes.ParseKey(s.PublicKey); err != nil {
return fmt.Errorf("%w: %s", ErrPublicKeyInvalid, s.PublicKey)
}
if s.PreSharedKey != "" {
if _, err := wgtypes.ParseKey(s.PreSharedKey); err != nil {
return ErrPreSharedKeyInvalid
}
}
switch {
case s.Endpoint == nil:
return ErrEndpointMissing
case s.Endpoint.IP == nil:
return ErrEndpointIPMissing
case s.Endpoint.Port == 0:
return ErrEndpointPortMissing
}
if len(s.Addresses) == 0 {
return ErrAddressMissing
}
for i, addr := range s.Addresses {
switch {
case addr == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressNil, i+1, len(s.Addresses))
case addr.IP == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressIPMissing, i+1, len(s.Addresses))
case addr.Mask == nil:
return fmt.Errorf("%w: for address %d of %d",
ErrAddressMaskMissing, i+1, len(s.Addresses))
}
}
if s.FirewallMark == 0 {
return ErrFirewallMarkMissing
}
return nil
}
func (s Settings) String() string {
lines := s.ToLines(ToLinesSettings{})
return strings.Join(lines, "\n")
}
type ToLinesSettings struct {
// Indent defaults to 4 spaces " ".
Indent *string
// FieldPrefix defaults to "├── ".
FieldPrefix *string
// LastFieldPrefix defaults to "└── ".
LastFieldPrefix *string
}
func (settings *ToLinesSettings) setDefaults() {
toStringPtr := func(s string) *string { return &s }
if settings.Indent == nil {
settings.Indent = toStringPtr(" ")
}
if settings.FieldPrefix == nil {
settings.FieldPrefix = toStringPtr("├── ")
}
if settings.LastFieldPrefix == nil {
settings.LastFieldPrefix = toStringPtr("└── ")
}
}
// ToLines serializes the settings to a slice of strings for display.
func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
settings.setDefaults()
indent := *settings.Indent
fieldPrefix := *settings.FieldPrefix
lastFieldPrefix := *settings.LastFieldPrefix
lines = append(lines, fieldPrefix+"Interface name: "+s.InterfaceName)
const (
set = "set"
notSet = "not set"
)
isSet := notSet
if s.PrivateKey != "" {
isSet = set
}
lines = append(lines, fieldPrefix+"Private key: "+isSet)
if s.PublicKey != "" {
lines = append(lines, fieldPrefix+"PublicKey: "+s.PublicKey)
}
isSet = notSet
if s.PreSharedKey != "" {
isSet = set
}
lines = append(lines, fieldPrefix+"Pre shared key: "+isSet)
endpointStr := notSet
if s.Endpoint != nil {
endpointStr = s.Endpoint.String()
}
lines = append(lines, fieldPrefix+"Endpoint: "+endpointStr)
if s.FirewallMark != 0 {
lines = append(lines, fieldPrefix+"Firewall mark: "+fmt.Sprint(s.FirewallMark))
}
if s.RulePriority != 0 {
lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority))
}
if len(s.Addresses) == 0 {
lines = append(lines, lastFieldPrefix+"Addresses: "+notSet)
} else {
lines = append(lines, lastFieldPrefix+"Addresses:")
for i, address := range s.Addresses {
prefix := fieldPrefix
if i == len(s.Addresses)-1 {
prefix = lastFieldPrefix
}
lines = append(lines, indent+prefix+address.String())
}
}
return lines
}

View File

@@ -0,0 +1,377 @@
package wireguard
import (
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Settings_SetDefaults(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
original Settings
expected Settings
}{
"empty settings": {
expected: Settings{
InterfaceName: "wg0",
FirewallMark: 51820,
},
},
"default endpoint port": {
original: Settings{
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
},
expected: Settings{
InterfaceName: "wg0",
FirewallMark: 51820,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
},
},
"not empty settings": {
original: Settings{
InterfaceName: "wg1",
FirewallMark: 999,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 9999,
},
},
expected: Settings{
InterfaceName: "wg1",
FirewallMark: 999,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 9999,
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
testCase.original.SetDefaults()
assert.Equal(t, testCase.expected, testCase.original)
})
}
}
func Test_Settings_Check(t *testing.T) {
t.Parallel()
const (
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
)
testCases := map[string]struct {
settings Settings
err error
}{
"empty settings": {
err: errors.New("invalid interface name: "),
},
"bad interface name": {
settings: Settings{
InterfaceName: "$H1T",
},
err: errors.New("invalid interface name: $H1T"),
},
"empty private key": {
settings: Settings{
InterfaceName: "wg0",
},
err: ErrPrivateKeyMissing,
},
"bad private key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: "bad key",
},
err: ErrPrivateKeyInvalid,
},
"empty public key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
},
err: ErrPublicKeyMissing,
},
"bad public key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: "bad key",
},
err: errors.New("cannot parse public key: bad key"),
},
"bad preshared key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
PreSharedKey: "bad key",
},
err: errors.New("cannot parse pre-shared key"),
},
"empty endpoint": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
},
err: ErrEndpointMissing,
},
"nil endpoint IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{},
},
err: ErrEndpointIPMissing,
},
"nil endpoint port": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
},
},
err: ErrEndpointPortMissing,
},
"no address": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
},
err: ErrAddressMissing,
},
"nil address": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{nil},
},
err: errors.New("interface address is nil: for address 1 of 1"),
},
"nil address IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{}},
},
err: errors.New("interface address IP is missing: for address 1 of 1"),
},
"nil address mask": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4)}},
},
err: errors.New("interface address mask is missing: for address 1 of 1"),
},
"zero firewall mark": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
},
err: ErrFirewallMarkMissing,
},
"all valid": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
FirewallMark: 999,
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
err := testCase.settings.Check()
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func toStringPtr(s string) *string { return &s }
func Test_ToLinesSettings_setDefaults(t *testing.T) {
t.Parallel()
settings := ToLinesSettings{
Indent: toStringPtr("indent"),
}
someFunc := func(settings ToLinesSettings) {
settings.setDefaults()
expectedSettings := ToLinesSettings{
Indent: toStringPtr("indent"),
FieldPrefix: toStringPtr("├── "),
LastFieldPrefix: toStringPtr("└── "),
}
assert.Equal(t, expectedSettings, settings)
}
someFunc(settings)
untouchedSettings := ToLinesSettings{
Indent: toStringPtr("indent"),
}
assert.Equal(t, untouchedSettings, settings)
}
func Test_Settings_String(t *testing.T) {
t.Parallel()
settings := Settings{
InterfaceName: "wg0",
}
const expected = `├── Interface name: wg0
├── Private key: not set
├── Pre shared key: not set
├── Endpoint: not set
└── Addresses: not set`
s := settings.String()
assert.Equal(t, expected, s)
}
func Test_Settings_Lines(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
settings Settings
lineSettings ToLinesSettings
lines []string
}{
"empty settings": {
lines: []string{
"├── Interface name: ",
"├── Private key: not set",
"├── Pre shared key: not set",
"├── Endpoint: not set",
"└── Addresses: not set",
},
},
"settings all set": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: "private key",
PublicKey: "public key",
PreSharedKey: "pre-shared key",
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
FirewallMark: 999,
RulePriority: 888,
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
},
},
lines: []string{
"├── Interface name: wg0",
"├── Private key: set",
"├── PublicKey: public key",
"├── Pre shared key: set",
"├── Endpoint: 1.2.3.4:51820",
"├── Firewall mark: 999",
"├── Rule priority: 888",
"└── Addresses:",
" ├── 1.1.1.1/24",
" └── 2.2.2.2/32",
},
},
"custom line settings": {
lineSettings: ToLinesSettings{
Indent: toStringPtr(" "),
FieldPrefix: toStringPtr("- "),
LastFieldPrefix: toStringPtr("* "),
},
settings: Settings{
InterfaceName: "wg0",
Addresses: []*net.IPNet{
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
},
},
lines: []string{
"- Interface name: wg0",
"- Private key: not set",
"- Pre shared key: not set",
"- Endpoint: not set",
"* Addresses:",
" - 1.1.1.1/24",
" * 2.2.2.2/32",
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
lines := testCase.settings.ToLines(testCase.lineSettings)
assert.Equal(t, testCase.lines, lines)
})
}
}