feat(internal/wireguard): opportunistic kernelspace

- Auto detect if kernelspace implementation is available
- Fallback to Go userspace implementation if kernel is not available
This commit is contained in:
Quentin McGaw
2021-12-14 11:03:36 +00:00
parent b9a9319cb4
commit cfa3bb3b64
14 changed files with 229 additions and 79 deletions

View File

@@ -1,6 +1,10 @@
package netlink
import "github.com/vishvananda/netlink"
import (
"fmt"
"github.com/vishvananda/netlink"
)
//nolint:revive
const (
@@ -8,3 +12,20 @@ const (
FAMILY_V4 = netlink.FAMILY_V4
FAMILY_V6 = netlink.FAMILY_V6
)
type WireguardChecker interface {
IsWireguardSupported() (ok bool, err error)
}
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
families, err := netlink.GenlFamilyList()
if err != nil {
return false, fmt.Errorf("cannot list gen 1 families: %w", err)
}
for _, family := range families {
if family.Name == "wireguard" {
return true, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,21 @@
package netlink
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_NetLink_IsWireguardSupported(t *testing.T) {
t.Skip() // TODO unskip once the data race problem with netlink.GenlFamilyList() is fixed
t.Parallel()
netLink := &NetLink{}
ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
if ok { // cannot assert since this depends on kernel
t.Log("wireguard is supported")
} else {
t.Log("wireguard is not supported")
}
}

View File

@@ -9,4 +9,5 @@ type NetLinker interface {
Linker
Router
Ruler
WireguardChecker
}

View File

@@ -3,8 +3,9 @@ package netlink
import "github.com/vishvananda/netlink"
type (
Link = netlink.Link
Bridge = netlink.Bridge
Link = netlink.Link
Bridge = netlink.Bridge
Wireguard = netlink.Wireguard
)
var _ Linker = (*NetLink)(nil)

View File

@@ -63,6 +63,21 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrList", reflect.TypeOf((*MockNetLinker)(nil).AddrList), arg0, arg1)
}
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
m.ctrl.T.Helper()

View File

@@ -88,7 +88,7 @@ func Test_Routing_addIPRule(t *testing.T) {
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy,
},
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"),
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"),
},
"add rule success": {
src: makeIPNet(t, 1),
@@ -193,7 +193,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy,
},
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"),
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"),
},
"rule deleted": {
src: makeIPNet(t, 1),

View File

@@ -10,9 +10,11 @@ type NetLinker interface {
RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
LinkAdd(link netlink.Link) (err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
IsWireguardSupported() (ok bool, err error)
}

View File

@@ -48,6 +48,35 @@ func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1)
}
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkAdd indicates an expected call of LinkAdd.
func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0)
}
// LinkByName mocks base method.
func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) {
m.ctrl.T.Helper()

View File

@@ -53,7 +53,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
Table: firewallMark,
},
routeAddErr: errDummy,
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820}"), //nolint:lll
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820 Realm: 0}"), //nolint:lll
},
}

View File

@@ -51,7 +51,7 @@ func Test_Wireguard_addRule(t *testing.T) {
SuppressPrefixlen: -1,
},
ruleAddErr: errDummy,
err: errors.New("dummy: when adding rule: ip rule 987: from <nil> table 456"),
err: errors.New("dummy: when adding rule: ip rule 987: from all to all table 456"),
},
"rule delete error": {
expectedRule: &netlink.Rule{
@@ -66,7 +66,7 @@ func Test_Wireguard_addRule(t *testing.T) {
SuppressPrefixlen: -1,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from <nil> table 456"),
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from all to all table 456"),
},
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"github.com/qdm12/gluetun/internal/netlink"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
@@ -15,7 +16,9 @@ import (
var (
ErrDetectIPv6 = errors.New("cannot detect IPv6 support")
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device")
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
@@ -23,6 +26,7 @@ var (
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
ErrAddAddress = errors.New("cannot add address to wireguard interface")
ErrConfigure = errors.New("cannot configure wireguard interface")
ErrDeviceInfo = errors.New("cannot get wireguard device information")
ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface")
ErrRuleAdd = errors.New("cannot add rule for interface")
@@ -41,6 +45,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
doKernel, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return
}
client, err := wgctrl.New()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
@@ -52,62 +62,21 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger)
tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err)
return
setupFunction := setupUserSpace
if doKernel {
w.logger.Info("Using available kernelspace implementation")
setupFunction = setupKernelSpace
} else {
w.logger.Info("Using userspace implementation since Kernel support does not exist")
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
link, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, &closers, w.logger)
if err != nil {
waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return
} else if tunName != w.settings.InterfaceName {
waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, w.settings.InterfaceName, tunName)
waitError <- err
return
}
link, err := w.netlink.LinkByName(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err)
return
}
bind := conn.NewDefaultBind()
closers.add("closing bind", stepSeven, bind.Close)
deviceLogger := makeDeviceLogger(w.logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepSix, func() error {
device.Close()
return nil
})
uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err)
return
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
err = w.addAddresses(link, w.settings.Addresses)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
@@ -128,9 +97,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link)
})
closers.add("deleting link", stepFive, func() error {
return w.netlink.LinkDel(link)
})
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
if err != nil {
@@ -158,20 +124,113 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
w.logger.Info("Wireguard is up")
ready <- struct{}{}
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = ErrDeviceWaited
waitError <- waitAndCleanup()
}
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
linkAttrs := netlink.LinkAttrs{
Name: interfaceName,
MTU: device.DefaultMTU, // TODO
}
link = &netlink.Wireguard{
LinkAttrs: linkAttrs,
}
err = netLinker.LinkAdd(link)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
})
waitAndCleanup = func() error {
<-ctx.Done()
closers.cleanup(logger)
return ctx.Err()
}
closers.cleanup(w.logger)
return link, waitAndCleanup, nil
}
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
tun, err := tun.CreateTUN(interfaceName, device.DefaultMTU)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
waitError <- err
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err = netLinker.LinkByName(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
})
bind := conn.NewDefaultBind()
closers.add("closing bind", stepSeven, bind.Close)
deviceLogger := makeDeviceLogger(logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepSix, func() error {
device.Close()
return nil
})
uapiFile, err := ipc.UAPIOpen(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
waitAndCleanup = func() error {
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-uapiAcceptErrorCh:
close(uapiAcceptErrorCh)
case <-device.Wait():
err = ErrDeviceWaited
}
closers.cleanup(logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
return err
}
return link, waitAndCleanup, nil
}
func acceptAndHandle(uapi net.Listener, device *device.Device,