chore(netlink): define own types with minimal fields

- Allow to swap `github.com/vishvananda/netlink`
- Allow to add build tags for each platform
- One step closer to development on non-Linux platforms
This commit is contained in:
Quentin McGaw
2023-05-29 06:44:58 +00:00
parent 163ac48ce4
commit 38ddcfa756
34 changed files with 828 additions and 493 deletions

View File

@@ -5,7 +5,6 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
)
func (w *Wireguard) addAddresses(link netlink.Link,
@@ -15,15 +14,14 @@ func (w *Wireguard) addAddresses(link netlink.Link,
continue
}
ipNet := ipNet
address := &netlink.Addr{
IPNet: routing.NetipPrefixToIPNet(&ipNet),
address := netlink.Addr{
Network: ipNet,
}
err = w.netlink.AddrReplace(link, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Attrs().Name)
err, address, link.Name)
}
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -18,14 +17,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64)
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
errDummy := errors.New("dummy")
testCases := map[string]struct {
@@ -35,15 +26,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err error
}{
"success": {
link: newLink(),
link: netlink.Link{Type: "wireguard"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -54,12 +45,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
},
"first add error": {
link: newLink(),
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
@@ -71,15 +62,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
},
"second add error": {
link: newLink(),
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -91,7 +82,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
},
"ignore IPv6": {
link: newLink(),
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
return &Wireguard{

View File

@@ -3,6 +3,7 @@ package wireguard
import (
"fmt"
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -53,8 +54,14 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
PublicKey: publicKey,
PresharedKey: preSharedKey,
AllowedIPs: []net.IPNet{
*allIPv4(),
*allIPv6(),
{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
},
{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
},
},
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
@@ -68,16 +75,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
return config, nil
}
func allIPv4() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
}
func allIPv4() (prefix netip.Prefix) {
const bits = 0
return netip.PrefixFrom(netip.IPv4Unspecified(), bits)
}
func allIPv6() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
}
func allIPv6() (prefix netip.Prefix) {
const bits = 0
return netip.PrefixFrom(netip.IPv6Unspecified(), bits)
}

View File

@@ -24,10 +24,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "test_8081"
link := &netlink.Bridge{
LinkAttrs: linkAttrs,
link := netlink.Link{
Type: "bridge",
Name: "test_8081",
}
// Remove any previously created test interface from a crashed/panic
@@ -37,8 +36,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
require.NoError(t, err)
}
err = netlinker.LinkAdd(link)
linkIndex, err := netlinker.LinkAdd(link)
require.NoError(t, err)
link.Index = linkIndex
defer func() {
err = netlinker.LinkDel(link)
@@ -63,14 +63,12 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
err = wg.addAddresses(link, addresses)
require.NoError(t, err)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FAMILY_ALL)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
require.NotNil(t, netlinkAddress.IPNet)
ipNet, err := netip.ParsePrefix(netlinkAddress.IPNet.String())
require.NoError(t, err)
assert.Equal(t, addresses[i], ipNet)
require.NotNil(t, netlinkAddress.Network)
assert.Equal(t, addresses[i], netlinkAddress.Network)
}
}
}
@@ -95,7 +93,7 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
assert.NoError(t, err)
}()
rules, err := netlinker.RuleList(netlink.FAMILY_V4)
rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err)
var rule netlink.Rule
var ruleFound bool
@@ -107,15 +105,10 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
}
require.True(t, ruleFound)
expectedRule := netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: 4294967295,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
}
assert.Equal(t, expectedRule, rule)

View File

@@ -5,7 +5,7 @@ import "github.com/qdm12/gluetun/internal/netlink"
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
Router
Ruler
Linker
@@ -13,21 +13,21 @@ type NetLinker interface {
}
type Router interface {
RouteList(link netlink.Link, family int) (
RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error
RouteAdd(route netlink.Route) error
}
type Ruler interface {
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkAdd(link netlink.Link) (err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
}

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink"
netlink "github.com/qdm12/gluetun/internal/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
@@ -35,7 +35,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -64,11 +64,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkAdd indicates an expected call of LinkAdd.
@@ -136,11 +137,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -150,7 +152,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
@@ -164,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route)
@@ -179,7 +181,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
@@ -193,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)

View File

@@ -2,17 +2,17 @@ package wireguard
import (
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
// TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark int) (err error) {
route := &netlink.Route{
LinkIndex: link.Attrs().Index,
route := netlink.Route{
LinkIndex: link.Index,
Dst: dst,
Table: firewallMark,
}
@@ -21,7 +21,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
if err != nil {
return fmt.Errorf(
"adding route for link %s, destination %s and table %d: %w",
link.Attrs().Name, dst, firewallMark, err)
link.Name, dst, firewallMark, err)
}
return err

View File

@@ -2,7 +2,7 @@ package wireguard
import (
"errors"
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
@@ -15,41 +15,40 @@ func Test_Wireguard_addRoute(t *testing.T) {
t.Parallel()
const linkIndex = 88
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
linkAttrs.Index = linkIndex
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
ipPrefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
const firewallMark = 51820
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst *net.IPNet
expectedRoute *netlink.Route
dst netip.Prefix
expectedRoute netlink.Route
routeAddErr error
err error
}{
"success": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
link: netlink.Link{
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Dst: ipPrefix,
Table: firewallMark,
},
},
"route add error": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
link: netlink.Link{
Name: "a_bridge",
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Dst: ipPrefix,
Table: firewallMark,
},
routeAddErr: errDummy,

View File

@@ -21,54 +21,39 @@ func Test_Wireguard_addRule(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
expectedRule *netlink.Rule
expectedRule netlink.Rule
ruleAddErr error
err error
ruleDelErr error
cleanupErr error
}{
"success": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family,
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Family: family,
},
},
"rule add error": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family,
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Family: family,
},
ruleAddErr: errDummy,
err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"),
},
"rule delete error": {
expectedRule: &netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family,
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Family: family,
},
ruleDelErr: errDummy,
cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"),

View File

@@ -93,10 +93,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
if err := w.netlink.LinkSetUp(link); err != nil {
linkIndex, err := w.netlink.LinkSetUp(link)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link)
})
@@ -161,17 +163,16 @@ func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
linkAttrs := netlink.LinkAttrs{
link = netlink.Link{
Type: "wireguard",
Name: interfaceName,
MTU: int(mtu),
MTU: mtu,
}
link = &netlink.Wireguard{
LinkAttrs: linkAttrs,
}
err = netLinker.LinkAdd(link)
linkIndex, err := netLinker.LinkAdd(link)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
})
@@ -191,22 +192,22 @@ func setupUserSpace(ctx context.Context,
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err = netLinker.LinkByName(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
@@ -226,14 +227,14 @@ func setupUserSpace(ctx context.Context,
uapiFile, err := ipc.UAPIOpen(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)