feat(wireguard): WIREGUARD_IMPLEMENTATION variable

- Can be `auto` (default), `userspace` or `kernelspace`
This commit is contained in:
Quentin McGaw
2022-12-02 11:16:27 +00:00
parent 1b1335835b
commit 03ed3cb1c8
9 changed files with 123 additions and 40 deletions

View File

@@ -97,6 +97,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PRESHARED_KEY= \ WIREGUARD_PRESHARED_KEY= \
WIREGUARD_PUBLIC_KEY= \ WIREGUARD_PUBLIC_KEY= \
WIREGUARD_ADDRESSES= \ WIREGUARD_ADDRESSES= \
WIREGUARD_IMPLEMENTATION=auto \
# VPN server filtering # VPN server filtering
SERVER_REGIONS= \ SERVER_REGIONS= \
SERVER_COUNTRIES= \ SERVER_COUNTRIES= \

View File

@@ -44,4 +44,5 @@ var (
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set") ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set") ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid") ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
) )

View File

@@ -27,6 +27,11 @@ type Wireguard struct {
// to create. It cannot be the empty string in the // to create. It cannot be the empty string in the
// internal state. // internal state.
Interface string Interface string
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string
// in the internal state.
Implementation string
} }
var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
@@ -85,6 +90,12 @@ func (w Wireguard) validate(vpnProvider string) (err error) {
ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName) ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName)
} }
validImplementations := []string{"auto", "userspace", "kernelspace"}
if !helpers.IsOneOf(w.Implementation, validImplementations...) {
return fmt.Errorf("%w: %s must be one of %s", ErrWireguardImplementationNotValid,
w.Implementation, helpers.ChoicesOrString(validImplementations))
}
return nil return nil
} }
@@ -94,6 +105,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
PreSharedKey: helpers.CopyStringPtr(w.PreSharedKey), PreSharedKey: helpers.CopyStringPtr(w.PreSharedKey),
Addresses: helpers.CopyIPNetSlice(w.Addresses), Addresses: helpers.CopyIPNetSlice(w.Addresses),
Interface: w.Interface, Interface: w.Interface,
Implementation: w.Implementation,
} }
} }
@@ -102,6 +114,7 @@ func (w *Wireguard) mergeWith(other Wireguard) {
w.PreSharedKey = helpers.MergeWithStringPtr(w.PreSharedKey, other.PreSharedKey) w.PreSharedKey = helpers.MergeWithStringPtr(w.PreSharedKey, other.PreSharedKey)
w.Addresses = helpers.MergeIPNetsSlices(w.Addresses, other.Addresses) w.Addresses = helpers.MergeIPNetsSlices(w.Addresses, other.Addresses)
w.Interface = helpers.MergeWithString(w.Interface, other.Interface) w.Interface = helpers.MergeWithString(w.Interface, other.Interface)
w.Implementation = helpers.MergeWithString(w.Implementation, other.Implementation)
} }
func (w *Wireguard) overrideWith(other Wireguard) { func (w *Wireguard) overrideWith(other Wireguard) {
@@ -109,12 +122,14 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.PreSharedKey = helpers.OverrideWithStringPtr(w.PreSharedKey, other.PreSharedKey) w.PreSharedKey = helpers.OverrideWithStringPtr(w.PreSharedKey, other.PreSharedKey)
w.Addresses = helpers.OverrideWithIPNetsSlice(w.Addresses, other.Addresses) w.Addresses = helpers.OverrideWithIPNetsSlice(w.Addresses, other.Addresses)
w.Interface = helpers.OverrideWithString(w.Interface, other.Interface) w.Interface = helpers.OverrideWithString(w.Interface, other.Interface)
w.Implementation = helpers.OverrideWithString(w.Implementation, other.Implementation)
} }
func (w *Wireguard) setDefaults() { func (w *Wireguard) setDefaults() {
w.PrivateKey = helpers.DefaultStringPtr(w.PrivateKey, "") w.PrivateKey = helpers.DefaultStringPtr(w.PrivateKey, "")
w.PreSharedKey = helpers.DefaultStringPtr(w.PreSharedKey, "") w.PreSharedKey = helpers.DefaultStringPtr(w.PreSharedKey, "")
w.Interface = helpers.DefaultString(w.Interface, "wg0") w.Interface = helpers.DefaultString(w.Interface, "wg0")
w.Implementation = helpers.DefaultString(w.Implementation, "auto")
} }
func (w Wireguard) String() string { func (w Wireguard) String() string {
@@ -141,5 +156,9 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
node.Appendf("Network interface: %s", w.Interface) node.Appendf("Network interface: %s", w.Interface)
if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation)
}
return node return node
} }

View File

@@ -3,6 +3,7 @@ package env
import ( import (
"fmt" "fmt"
"net" "net"
"os"
"strings" "strings"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
@@ -15,6 +16,7 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
wireguard.PrivateKey = envToStringPtr("WIREGUARD_PRIVATE_KEY") wireguard.PrivateKey = envToStringPtr("WIREGUARD_PRIVATE_KEY")
wireguard.PreSharedKey = envToStringPtr("WIREGUARD_PRESHARED_KEY") wireguard.PreSharedKey = envToStringPtr("WIREGUARD_PRESHARED_KEY")
_, wireguard.Interface = s.getEnvWithRetro("VPN_INTERFACE", "WIREGUARD_INTERFACE") _, wireguard.Interface = s.getEnvWithRetro("VPN_INTERFACE", "WIREGUARD_INTERFACE")
wireguard.Implementation = os.Getenv("WIREGUARD_IMPLEMENTATION")
wireguard.Addresses, err = s.readWireguardAddresses() wireguard.Addresses, err = s.readWireguardAddresses()
if err != nil { if err != nil {
return wireguard, err // already wrapped return wireguard, err // already wrapped

View File

@@ -14,6 +14,7 @@ func BuildWireguardSettings(connection models.Connection,
settings.PublicKey = connection.PubKey settings.PublicKey = connection.PubKey
settings.PreSharedKey = *userSettings.PreSharedKey settings.PreSharedKey = *userSettings.PreSharedKey
settings.InterfaceName = userSettings.Interface settings.InterfaceName = userSettings.Interface
settings.Implementation = userSettings.Implementation
settings.IPv6 = &ipv6Supported settings.IPv6 = &ipv6Supported
const rulePriority = 101 // 100 is to receive external connections const rulePriority = 101 // 100 is to receive external connections

View File

@@ -56,6 +56,7 @@ func Test_New(t *testing.T) {
}, },
FirewallMark: 100, FirewallMark: 100,
IPv6: ptr(false), IPv6: ptr(false),
Implementation: "auto",
}, },
}, },
}, },

View File

@@ -31,16 +31,37 @@ var (
ErrIfaceUp = errors.New("cannot set the interface to UP") ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface") ErrRouteAdd = errors.New("cannot add route for interface")
ErrDeviceWaited = errors.New("device waited for") ErrDeviceWaited = errors.New("device waited for")
ErrKernelSupport = errors.New("kernel does not support Wireguard")
) )
// See https://git.zx2c4.com/wireguard-go/tree/main.go // See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
doKernel, err := w.netlink.IsWireguardSupported() kernelSupported, err := w.netlink.IsWireguardSupported()
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err) waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return return
} }
setupFunction := setupUserSpace
switch w.settings.Implementation {
case "auto": //nolint:goconst
if !kernelSupported {
w.logger.Info("Using userspace implementation since Kernel support does not exist")
break
}
w.logger.Info("Using available kernelspace implementation")
setupFunction = setupKernelSpace
case "userspace":
case "kernelspace":
if !kernelSupported {
waitError <- fmt.Errorf("%w", ErrKernelSupport)
return
}
setupFunction = setupKernelSpace
default:
panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation))
}
client, err := wgctrl.New() client, err := wgctrl.New()
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err) waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
@@ -52,14 +73,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger) defer closers.cleanup(w.logger)
setupFunction := setupUserSpace
if doKernel {
w.logger.Info("Using available kernelspace implementation")
setupFunction = setupKernelSpace
} else {
w.logger.Info("Using userspace implementation since Kernel support does not exist")
}
link, waitAndCleanup, err := setupFunction(ctx, link, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, &closers, w.logger) w.settings.InterfaceName, w.netlink, &closers, w.logger)
if err != nil { if err != nil {

View File

@@ -33,6 +33,9 @@ type Settings struct {
// IPv6 can bet set to true if IPv6 should be handled. // IPv6 can bet set to true if IPv6 should be handled.
// It defaults to false if left unset. // It defaults to false if left unset.
IPv6 *bool IPv6 *bool
// Implementation is the implementation to use.
// It can be auto, kernelspace or userspace, and defaults to auto.
Implementation string
} }
func (s *Settings) SetDefaults() { func (s *Settings) SetDefaults() {
@@ -55,6 +58,11 @@ func (s *Settings) SetDefaults() {
ipv6 := false // this should be injected from host ipv6 := false // this should be injected from host
s.IPv6 = &ipv6 s.IPv6 = &ipv6
} }
if s.Implementation == "" {
const defaultImplementation = "auto"
s.Implementation = defaultImplementation
}
} }
var ( var (
@@ -72,6 +80,7 @@ var (
ErrAddressIPMissing = errors.New("interface address IP is missing") ErrAddressIPMissing = errors.New("interface address IP is missing")
ErrAddressMaskMissing = errors.New("interface address mask is missing") ErrAddressMaskMissing = errors.New("interface address mask is missing")
ErrFirewallMarkMissing = errors.New("firewall mark is missing") ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
) )
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
@@ -129,6 +138,12 @@ func (s *Settings) Check() (err error) {
return ErrFirewallMarkMissing return ErrFirewallMarkMissing
} }
switch s.Implementation {
case "auto", "kernelspace", "userspace":
default:
return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation)
}
return nil return nil
} }
@@ -209,6 +224,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority)) lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority))
} }
if s.Implementation != "auto" {
lines = append(lines, fieldPrefix+"Implementation: "+s.Implementation)
}
if len(s.Addresses) == 0 { if len(s.Addresses) == 0 {
lines = append(lines, lastFieldPrefix+"Addresses: "+notSet) lines = append(lines, lastFieldPrefix+"Addresses: "+notSet)
} else { } else {

View File

@@ -23,6 +23,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg0", InterfaceName: "wg0",
FirewallMark: 51820, FirewallMark: 51820,
IPv6: ptr(false), IPv6: ptr(false),
Implementation: "auto",
}, },
}, },
"default endpoint port": { "default endpoint port": {
@@ -39,6 +40,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
Port: 51820, Port: 51820,
}, },
IPv6: ptr(false), IPv6: ptr(false),
Implementation: "auto",
}, },
}, },
"not empty settings": { "not empty settings": {
@@ -50,6 +52,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
Port: 9999, Port: 9999,
}, },
IPv6: ptr(true), IPv6: ptr(true),
Implementation: "userspace",
}, },
expected: Settings{ expected: Settings{
InterfaceName: "wg1", InterfaceName: "wg1",
@@ -59,6 +62,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
Port: 9999, Port: 9999,
}, },
IPv6: ptr(true), IPv6: ptr(true),
Implementation: "userspace",
}, },
}, },
} }
@@ -225,6 +229,21 @@ func Test_Settings_Check(t *testing.T) {
}, },
err: ErrFirewallMarkMissing, err: ErrFirewallMarkMissing,
}, },
"invalid implementation": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: &net.UDPAddr{
IP: net.IPv4(1, 2, 3, 4),
Port: 51820,
},
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
FirewallMark: 999,
Implementation: "x",
},
err: errors.New("invalid implementation: x"),
},
"all valid": { "all valid": {
settings: Settings{ settings: Settings{
InterfaceName: "wg0", InterfaceName: "wg0",
@@ -236,6 +255,7 @@ func Test_Settings_Check(t *testing.T) {
}, },
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
FirewallMark: 999, FirewallMark: 999,
Implementation: "userspace",
}, },
}, },
} }
@@ -289,12 +309,14 @@ func Test_Settings_String(t *testing.T) {
settings := Settings{ settings := Settings{
InterfaceName: "wg0", InterfaceName: "wg0",
IPv6: ptr(true), IPv6: ptr(true),
Implementation: "x",
} }
const expected = `├── Interface name: wg0 const expected = `├── Interface name: wg0
├── Private key: not set ├── Private key: not set
├── Pre shared key: not set ├── Pre shared key: not set
├── Endpoint: not set ├── Endpoint: not set
├── IPv6: enabled ├── IPv6: enabled
├── Implementation: x
└── Addresses: not set` └── Addresses: not set`
s := settings.String() s := settings.String()
assert.Equal(t, expected, s) assert.Equal(t, expected, s)
@@ -318,6 +340,7 @@ func Test_Settings_Lines(t *testing.T) {
"├── Pre shared key: not set", "├── Pre shared key: not set",
"├── Endpoint: not set", "├── Endpoint: not set",
"├── IPv6: disabled", "├── IPv6: disabled",
"├── Implementation: ",
"└── Addresses: not set", "└── Addresses: not set",
}, },
}, },
@@ -338,6 +361,7 @@ func Test_Settings_Lines(t *testing.T) {
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)}, {IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
}, },
IPv6: ptr(true), IPv6: ptr(true),
Implementation: "userspace",
}, },
lines: []string{ lines: []string{
"├── Interface name: wg0", "├── Interface name: wg0",
@@ -348,6 +372,7 @@ func Test_Settings_Lines(t *testing.T) {
"├── IPv6: enabled", "├── IPv6: enabled",
"├── Firewall mark: 999", "├── Firewall mark: 999",
"├── Rule priority: 888", "├── Rule priority: 888",
"├── Implementation: userspace",
"└── Addresses:", "└── Addresses:",
" ├── 1.1.1.1/24", " ├── 1.1.1.1/24",
" └── 2.2.2.2/32", " └── 2.2.2.2/32",
@@ -373,6 +398,7 @@ func Test_Settings_Lines(t *testing.T) {
"- Pre shared key: not set", "- Pre shared key: not set",
"- Endpoint: not set", "- Endpoint: not set",
"- IPv6: disabled", "- IPv6: disabled",
"- Implementation: ",
"* Addresses:", "* Addresses:",
" - 1.1.1.1/24", " - 1.1.1.1/24",
" * 2.2.2.2/32", " * 2.2.2.2/32",