diff --git a/Dockerfile b/Dockerfile index 9604a1fe..f8e0a28d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -97,6 +97,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ WIREGUARD_PRESHARED_KEY= \ WIREGUARD_PUBLIC_KEY= \ WIREGUARD_ADDRESSES= \ + WIREGUARD_IMPLEMENTATION=auto \ # VPN server filtering SERVER_REGIONS= \ SERVER_COUNTRIES= \ diff --git a/internal/configuration/settings/errors.go b/internal/configuration/settings/errors.go index d2456697..cc40c459 100644 --- a/internal/configuration/settings/errors.go +++ b/internal/configuration/settings/errors.go @@ -44,4 +44,5 @@ var ( ErrWireguardPrivateKeyNotSet = errors.New("private key is not set") ErrWireguardPublicKeyNotSet = errors.New("public key is not set") ErrWireguardPublicKeyNotValid = errors.New("public key is not valid") + ErrWireguardImplementationNotValid = errors.New("implementation is not valid") ) diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index a45d9cc4..d6d3c06d 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -27,6 +27,11 @@ type Wireguard struct { // to create. It cannot be the empty string in the // internal state. Interface string + // Implementation is the Wireguard implementation to use. + // It can be "auto", "userspace" or "kernelspace". + // It defaults to "auto" and cannot be the empty string + // in the internal state. + Implementation string } var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) @@ -85,15 +90,22 @@ func (w Wireguard) validate(vpnProvider string) (err error) { ErrWireguardInterfaceNotValid, w.Interface, regexpInterfaceName) } + validImplementations := []string{"auto", "userspace", "kernelspace"} + if !helpers.IsOneOf(w.Implementation, validImplementations...) { + return fmt.Errorf("%w: %s must be one of %s", ErrWireguardImplementationNotValid, + w.Implementation, helpers.ChoicesOrString(validImplementations)) + } + return nil } func (w *Wireguard) copy() (copied Wireguard) { return Wireguard{ - PrivateKey: helpers.CopyStringPtr(w.PrivateKey), - PreSharedKey: helpers.CopyStringPtr(w.PreSharedKey), - Addresses: helpers.CopyIPNetSlice(w.Addresses), - Interface: w.Interface, + PrivateKey: helpers.CopyStringPtr(w.PrivateKey), + PreSharedKey: helpers.CopyStringPtr(w.PreSharedKey), + Addresses: helpers.CopyIPNetSlice(w.Addresses), + Interface: w.Interface, + Implementation: w.Implementation, } } @@ -102,6 +114,7 @@ func (w *Wireguard) mergeWith(other Wireguard) { w.PreSharedKey = helpers.MergeWithStringPtr(w.PreSharedKey, other.PreSharedKey) w.Addresses = helpers.MergeIPNetsSlices(w.Addresses, other.Addresses) w.Interface = helpers.MergeWithString(w.Interface, other.Interface) + w.Implementation = helpers.MergeWithString(w.Implementation, other.Implementation) } func (w *Wireguard) overrideWith(other Wireguard) { @@ -109,12 +122,14 @@ func (w *Wireguard) overrideWith(other Wireguard) { w.PreSharedKey = helpers.OverrideWithStringPtr(w.PreSharedKey, other.PreSharedKey) w.Addresses = helpers.OverrideWithIPNetsSlice(w.Addresses, other.Addresses) w.Interface = helpers.OverrideWithString(w.Interface, other.Interface) + w.Implementation = helpers.OverrideWithString(w.Implementation, other.Implementation) } func (w *Wireguard) setDefaults() { w.PrivateKey = helpers.DefaultStringPtr(w.PrivateKey, "") w.PreSharedKey = helpers.DefaultStringPtr(w.PreSharedKey, "") w.Interface = helpers.DefaultString(w.Interface, "wg0") + w.Implementation = helpers.DefaultString(w.Implementation, "auto") } func (w Wireguard) String() string { @@ -141,5 +156,9 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) { node.Appendf("Network interface: %s", w.Interface) + if w.Implementation != "auto" { + node.Appendf("Implementation: %s", w.Implementation) + } + return node } diff --git a/internal/configuration/sources/env/wireguard.go b/internal/configuration/sources/env/wireguard.go index 9801d22b..a8ae7c97 100644 --- a/internal/configuration/sources/env/wireguard.go +++ b/internal/configuration/sources/env/wireguard.go @@ -3,6 +3,7 @@ package env import ( "fmt" "net" + "os" "strings" "github.com/qdm12/gluetun/internal/configuration/settings" @@ -15,6 +16,7 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) { wireguard.PrivateKey = envToStringPtr("WIREGUARD_PRIVATE_KEY") wireguard.PreSharedKey = envToStringPtr("WIREGUARD_PRESHARED_KEY") _, wireguard.Interface = s.getEnvWithRetro("VPN_INTERFACE", "WIREGUARD_INTERFACE") + wireguard.Implementation = os.Getenv("WIREGUARD_IMPLEMENTATION") wireguard.Addresses, err = s.readWireguardAddresses() if err != nil { return wireguard, err // already wrapped diff --git a/internal/provider/utils/wireguard.go b/internal/provider/utils/wireguard.go index 0a48af25..1a5de932 100644 --- a/internal/provider/utils/wireguard.go +++ b/internal/provider/utils/wireguard.go @@ -14,6 +14,7 @@ func BuildWireguardSettings(connection models.Connection, settings.PublicKey = connection.PubKey settings.PreSharedKey = *userSettings.PreSharedKey settings.InterfaceName = userSettings.Interface + settings.Implementation = userSettings.Implementation settings.IPv6 = &ipv6Supported const rulePriority = 101 // 100 is to receive external connections diff --git a/internal/wireguard/constructor_test.go b/internal/wireguard/constructor_test.go index ef4259a3..d6ebee94 100644 --- a/internal/wireguard/constructor_test.go +++ b/internal/wireguard/constructor_test.go @@ -54,8 +54,9 @@ func Test_New(t *testing.T) { IP: net.IPv4(5, 6, 7, 8), Mask: net.IPv4Mask(255, 255, 255, 255)}, }, - FirewallMark: 100, - IPv6: ptr(false), + FirewallMark: 100, + IPv6: ptr(false), + Implementation: "auto", }, }, }, diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index 8060f835..cfb771d5 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -31,16 +31,37 @@ var ( ErrIfaceUp = errors.New("cannot set the interface to UP") ErrRouteAdd = errors.New("cannot add route for interface") ErrDeviceWaited = errors.New("device waited for") + ErrKernelSupport = errors.New("kernel does not support Wireguard") ) // See https://git.zx2c4.com/wireguard-go/tree/main.go func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { - doKernel, err := w.netlink.IsWireguardSupported() + kernelSupported, err := w.netlink.IsWireguardSupported() if err != nil { waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err) return } + setupFunction := setupUserSpace + switch w.settings.Implementation { + case "auto": //nolint:goconst + if !kernelSupported { + w.logger.Info("Using userspace implementation since Kernel support does not exist") + break + } + w.logger.Info("Using available kernelspace implementation") + setupFunction = setupKernelSpace + case "userspace": + case "kernelspace": + if !kernelSupported { + waitError <- fmt.Errorf("%w", ErrKernelSupport) + return + } + setupFunction = setupKernelSpace + default: + panic(fmt.Sprintf("unknown implementation %q", w.settings.Implementation)) + } + client, err := wgctrl.New() if err != nil { waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err) @@ -52,14 +73,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< defer closers.cleanup(w.logger) - setupFunction := setupUserSpace - if doKernel { - w.logger.Info("Using available kernelspace implementation") - setupFunction = setupKernelSpace - } else { - w.logger.Info("Using userspace implementation since Kernel support does not exist") - } - link, waitAndCleanup, err := setupFunction(ctx, w.settings.InterfaceName, w.netlink, &closers, w.logger) if err != nil { diff --git a/internal/wireguard/settings.go b/internal/wireguard/settings.go index f9a09532..d6513bdf 100644 --- a/internal/wireguard/settings.go +++ b/internal/wireguard/settings.go @@ -33,6 +33,9 @@ type Settings struct { // IPv6 can bet set to true if IPv6 should be handled. // It defaults to false if left unset. IPv6 *bool + // Implementation is the implementation to use. + // It can be auto, kernelspace or userspace, and defaults to auto. + Implementation string } func (s *Settings) SetDefaults() { @@ -55,23 +58,29 @@ func (s *Settings) SetDefaults() { ipv6 := false // this should be injected from host s.IPv6 = &ipv6 } + + if s.Implementation == "" { + const defaultImplementation = "auto" + s.Implementation = defaultImplementation + } } var ( - ErrInterfaceNameInvalid = errors.New("invalid interface name") - ErrPrivateKeyMissing = errors.New("private key is missing") - ErrPrivateKeyInvalid = errors.New("cannot parse private key") - ErrPublicKeyMissing = errors.New("public key is missing") - ErrPublicKeyInvalid = errors.New("cannot parse public key") - ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key") - ErrEndpointMissing = errors.New("endpoint is missing") - ErrEndpointIPMissing = errors.New("endpoint IP is missing") - ErrEndpointPortMissing = errors.New("endpoint port is missing") - ErrAddressMissing = errors.New("interface address is missing") - ErrAddressNil = errors.New("interface address is nil") - ErrAddressIPMissing = errors.New("interface address IP is missing") - ErrAddressMaskMissing = errors.New("interface address mask is missing") - ErrFirewallMarkMissing = errors.New("firewall mark is missing") + ErrInterfaceNameInvalid = errors.New("invalid interface name") + ErrPrivateKeyMissing = errors.New("private key is missing") + ErrPrivateKeyInvalid = errors.New("cannot parse private key") + ErrPublicKeyMissing = errors.New("public key is missing") + ErrPublicKeyInvalid = errors.New("cannot parse public key") + ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key") + ErrEndpointMissing = errors.New("endpoint is missing") + ErrEndpointIPMissing = errors.New("endpoint IP is missing") + ErrEndpointPortMissing = errors.New("endpoint port is missing") + ErrAddressMissing = errors.New("interface address is missing") + ErrAddressNil = errors.New("interface address is nil") + ErrAddressIPMissing = errors.New("interface address IP is missing") + ErrAddressMaskMissing = errors.New("interface address mask is missing") + ErrFirewallMarkMissing = errors.New("firewall mark is missing") + ErrImplementationInvalid = errors.New("invalid implementation") ) var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) @@ -129,6 +138,12 @@ func (s *Settings) Check() (err error) { return ErrFirewallMarkMissing } + switch s.Implementation { + case "auto", "kernelspace", "userspace": + default: + return fmt.Errorf("%w: %s", ErrImplementationInvalid, s.Implementation) + } + return nil } @@ -209,6 +224,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) { lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority)) } + if s.Implementation != "auto" { + lines = append(lines, fieldPrefix+"Implementation: "+s.Implementation) + } + if len(s.Addresses) == 0 { lines = append(lines, lastFieldPrefix+"Addresses: "+notSet) } else { diff --git a/internal/wireguard/settings_test.go b/internal/wireguard/settings_test.go index 53c957db..b49dded1 100644 --- a/internal/wireguard/settings_test.go +++ b/internal/wireguard/settings_test.go @@ -20,9 +20,10 @@ func Test_Settings_SetDefaults(t *testing.T) { }{ "empty settings": { expected: Settings{ - InterfaceName: "wg0", - FirewallMark: 51820, - IPv6: ptr(false), + InterfaceName: "wg0", + FirewallMark: 51820, + IPv6: ptr(false), + Implementation: "auto", }, }, "default endpoint port": { @@ -38,7 +39,8 @@ func Test_Settings_SetDefaults(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - IPv6: ptr(false), + IPv6: ptr(false), + Implementation: "auto", }, }, "not empty settings": { @@ -49,7 +51,8 @@ func Test_Settings_SetDefaults(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 9999, }, - IPv6: ptr(true), + IPv6: ptr(true), + Implementation: "userspace", }, expected: Settings{ InterfaceName: "wg1", @@ -58,7 +61,8 @@ func Test_Settings_SetDefaults(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 9999, }, - IPv6: ptr(true), + IPv6: ptr(true), + Implementation: "userspace", }, }, } @@ -225,6 +229,21 @@ func Test_Settings_Check(t *testing.T) { }, err: ErrFirewallMarkMissing, }, + "invalid implementation": { + settings: Settings{ + InterfaceName: "wg0", + PrivateKey: validKey1, + PublicKey: validKey2, + Endpoint: &net.UDPAddr{ + IP: net.IPv4(1, 2, 3, 4), + Port: 51820, + }, + Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, + FirewallMark: 999, + Implementation: "x", + }, + err: errors.New("invalid implementation: x"), + }, "all valid": { settings: Settings{ InterfaceName: "wg0", @@ -234,8 +253,9 @@ func Test_Settings_Check(t *testing.T) { IP: net.IPv4(1, 2, 3, 4), Port: 51820, }, - Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, - FirewallMark: 999, + Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}}, + FirewallMark: 999, + Implementation: "userspace", }, }, } @@ -287,14 +307,16 @@ func Test_Settings_String(t *testing.T) { t.Parallel() settings := Settings{ - InterfaceName: "wg0", - IPv6: ptr(true), + InterfaceName: "wg0", + IPv6: ptr(true), + Implementation: "x", } const expected = `├── Interface name: wg0 ├── Private key: not set ├── Pre shared key: not set ├── Endpoint: not set ├── IPv6: enabled +├── Implementation: x └── Addresses: not set` s := settings.String() assert.Equal(t, expected, s) @@ -318,6 +340,7 @@ func Test_Settings_Lines(t *testing.T) { "├── Pre shared key: not set", "├── Endpoint: not set", "├── IPv6: disabled", + "├── Implementation: ", "└── Addresses: not set", }, }, @@ -337,7 +360,8 @@ func Test_Settings_Lines(t *testing.T) { {IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)}, {IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)}, }, - IPv6: ptr(true), + IPv6: ptr(true), + Implementation: "userspace", }, lines: []string{ "├── Interface name: wg0", @@ -348,6 +372,7 @@ func Test_Settings_Lines(t *testing.T) { "├── IPv6: enabled", "├── Firewall mark: 999", "├── Rule priority: 888", + "├── Implementation: userspace", "└── Addresses:", " ├── 1.1.1.1/24", " └── 2.2.2.2/32", @@ -373,6 +398,7 @@ func Test_Settings_Lines(t *testing.T) { "- Pre shared key: not set", "- Endpoint: not set", "- IPv6: disabled", + "- Implementation: ", "* Addresses:", " - 1.1.1.1/24", " * 2.2.2.2/32",