diff --git a/internal/wireguard/address.go b/internal/wireguard/address.go index 7e0aaf34..b6c50b44 100644 --- a/internal/wireguard/address.go +++ b/internal/wireguard/address.go @@ -10,6 +10,11 @@ import ( func (w *Wireguard) addAddresses(link netlink.Link, addresses []*net.IPNet) (err error) { for _, ipNet := range addresses { + ipNetIsIPv6 := ipNet.IP.To4() == nil + if !*w.settings.IPv6 && ipNetIsIPv6 { + continue + } + address := &netlink.Addr{ IPNet: ipNet, } diff --git a/internal/wireguard/address_test.go b/internal/wireguard/address_test.go index bbbb8569..34ff7ae7 100644 --- a/internal/wireguard/address_test.go +++ b/internal/wireguard/address_test.go @@ -15,7 +15,7 @@ func Test_Wireguard_addAddresses(t *testing.T) { t.Parallel() ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)} - ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)} + ipNetTwo := &net.IPNet{IP: net.ParseIP("::1234"), Mask: net.CIDRMask(64, 128)} newLink := func() netlink.Link { linkAttrs := netlink.NewLinkAttrs() @@ -28,37 +28,77 @@ func Test_Wireguard_addAddresses(t *testing.T) { errDummy := errors.New("dummy") testCases := map[string]struct { - link netlink.Link - addrs []*net.IPNet - expectedAddrs []*netlink.Addr - addrAddErrs []error - err error + link netlink.Link + addrs []*net.IPNet + wgBuilder func(ctrl *gomock.Controller, link netlink.Link) *Wireguard + err error }{ "success": { link: newLink(), addrs: []*net.IPNet{ipNetOne, ipNetTwo}, - expectedAddrs: []*netlink.Addr{ - {IPNet: ipNetOne}, {IPNet: ipNetTwo}, + wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + netLinker := NewMockNetLinker(ctrl) + firstCall := netLinker.EXPECT(). + AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}). + Return(nil) + netLinker.EXPECT(). + AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}). + Return(nil).After(firstCall) + return &Wireguard{ + netlink: netLinker, + settings: Settings{ + IPv6: ptrTo(true), + }, + } }, - addrAddErrs: []error{nil, nil}, }, "first add error": { link: newLink(), addrs: []*net.IPNet{ipNetOne, ipNetTwo}, - expectedAddrs: []*netlink.Addr{ - {IPNet: ipNetOne}, + wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + netLinker := NewMockNetLinker(ctrl) + netLinker.EXPECT(). + AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}). + Return(errDummy) + return &Wireguard{ + netlink: netLinker, + settings: Settings{ + IPv6: ptrTo(true), + }, + } }, - addrAddErrs: []error{errDummy}, - err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"), + err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"), }, "second add error": { link: newLink(), addrs: []*net.IPNet{ipNetOne, ipNetTwo}, - expectedAddrs: []*netlink.Addr{ - {IPNet: ipNetOne}, {IPNet: ipNetTwo}, + wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + netLinker := NewMockNetLinker(ctrl) + firstCall := netLinker.EXPECT(). + AddrAdd(link, &netlink.Addr{IPNet: ipNetOne}). + Return(nil) + netLinker.EXPECT(). + AddrAdd(link, &netlink.Addr{IPNet: ipNetTwo}). + Return(errDummy).After(firstCall) + return &Wireguard{ + netlink: netLinker, + settings: Settings{ + IPv6: ptrTo(true), + }, + } + }, + err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"), + }, + "ignore IPv6": { + link: newLink(), + addrs: []*net.IPNet{ipNetTwo}, + wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { + return &Wireguard{ + settings: Settings{ + IPv6: ptrTo(false), + }, + } }, - addrAddErrs: []error{nil, errDummy}, - err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"), }, } @@ -68,18 +108,7 @@ func Test_Wireguard_addAddresses(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs)) - - netLinker := NewMockNetLinker(ctrl) - wg := Wireguard{ - netlink: netLinker, - } - - for i := range testCase.expectedAddrs { - netLinker.EXPECT(). - AddrAdd(testCase.link, testCase.expectedAddrs[i]). - Return(testCase.addrAddErrs[i]) - } + wg := testCase.wgBuilder(ctrl, testCase.link) err := wg.addAddresses(testCase.link, testCase.addrs) diff --git a/internal/wireguard/helpers_test.go b/internal/wireguard/helpers_test.go new file mode 100644 index 00000000..2db84e77 --- /dev/null +++ b/internal/wireguard/helpers_test.go @@ -0,0 +1,3 @@ +package wireguard + +func ptrTo[T any](x T) *T { return &x }