fix(wireguard): ignore IPv6 addresses if IPv6 disabled

This commit is contained in:
Quentin McGaw
2022-10-17 06:31:32 +00:00
parent c954e6f231
commit f70609c464
3 changed files with 66 additions and 29 deletions

View File

@@ -10,6 +10,11 @@ import (
func (w *Wireguard) addAddresses(link netlink.Link, func (w *Wireguard) addAddresses(link netlink.Link,
addresses []*net.IPNet) (err error) { addresses []*net.IPNet) (err error) {
for _, ipNet := range addresses { for _, ipNet := range addresses {
ipNetIsIPv6 := ipNet.IP.To4() == nil
if !*w.settings.IPv6 && ipNetIsIPv6 {
continue
}
address := &netlink.Addr{ address := &netlink.Addr{
IPNet: ipNet, IPNet: ipNet,
} }

View File

@@ -15,7 +15,7 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel() t.Parallel()
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)} ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)} ipNetTwo := &net.IPNet{IP: net.ParseIP("::1234"), Mask: net.CIDRMask(64, 128)}
newLink := func() netlink.Link { newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs() linkAttrs := netlink.NewLinkAttrs()
@@ -28,37 +28,77 @@ func Test_Wireguard_addAddresses(t *testing.T) {
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
testCases := map[string]struct { testCases := map[string]struct {
link netlink.Link link netlink.Link
addrs []*net.IPNet addrs []*net.IPNet
expectedAddrs []*netlink.Addr wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard
addrAddErrs []error err error
err error
}{ }{
"success": { "success": {
link: newLink(), link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo}, addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{ wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
{IPNet: ipNetOne}, {IPNet: ipNetTwo}, netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
}, },
addrAddErrs: []error{nil, nil},
}, },
"first add error": { "first add error": {
link: newLink(), link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo}, addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{ wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
{IPNet: ipNetOne}, netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
}, },
addrAddErrs: []error{errDummy}, err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
}, },
"second add error": { "second add error": {
link: newLink(), link: newLink(),
addrs: []*net.IPNet{ipNetOne, ipNetTwo}, addrs: []*net.IPNet{ipNetOne, ipNetTwo},
expectedAddrs: []*netlink.Addr{ wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
{IPNet: ipNetOne}, {IPNet: ipNetTwo}, netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
settings: Settings{
IPv6: ptrTo(true),
},
}
},
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
},
"ignore IPv6": {
link: newLink(),
addrs: []*net.IPNet{ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
return &Wireguard{
settings: Settings{
IPv6: ptrTo(false),
},
}
}, },
addrAddErrs: []error{nil, errDummy},
err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"),
}, },
} }
@@ -68,18 +108,7 @@ func Test_Wireguard_addAddresses(t *testing.T) {
t.Parallel() t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs)) wg := testCase.wgBuilder(ctrl, testCase.link)
netLinker := NewMockNetLinker(ctrl)
wg := Wireguard{
netlink: netLinker,
}
for i := range testCase.expectedAddrs {
netLinker.EXPECT().
AddrAdd(testCase.link, testCase.expectedAddrs[i]).
Return(testCase.addrAddErrs[i])
}
err := wg.addAddresses(testCase.link, testCase.addrs) err := wg.addAddresses(testCase.link, testCase.addrs)

View File

@@ -0,0 +1,3 @@
package wireguard
func ptrTo[T any](x T) *T { return &x }