diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index a5c3a66b..54103160 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -531,30 +531,30 @@ type netLinker interface { type Addresser interface { AddrList(link netlink.Link, family int) ( addresses []netlink.Addr, err error) - AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrReplace(link netlink.Link, addr netlink.Addr) error } type Router interface { - RouteList(link netlink.Link, family int) ( + RouteList(link *netlink.Link, family int) ( routes []netlink.Route, err error) - RouteAdd(route *netlink.Route) error - RouteDel(route *netlink.Route) error - RouteReplace(route *netlink.Route) error + RouteAdd(route netlink.Route) error + RouteDel(route netlink.Route) error + RouteReplace(route netlink.Route) error } type Ruler interface { RuleList(family int) (rules []netlink.Rule, err error) - RuleAdd(rule *netlink.Rule) error - RuleDel(rule *netlink.Rule) error + RuleAdd(rule netlink.Rule) error + RuleDel(rule netlink.Rule) error } type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) LinkByIndex(index int) (link netlink.Link, err error) - LinkAdd(link netlink.Link) (err error) + LinkAdd(link netlink.Link) (linkIndex int, err error) LinkDel(link netlink.Link) (err error) - LinkSetUp(link netlink.Link) (err error) + LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) } diff --git a/internal/netlink/address.go b/internal/netlink/address.go index cabf74f2..6340d145 100644 --- a/internal/netlink/address.go +++ b/internal/netlink/address.go @@ -1,14 +1,40 @@ package netlink -import "github.com/vishvananda/netlink" +import ( + "net/netip" -type Addr = netlink.Addr + "github.com/vishvananda/netlink" +) + +type Addr struct { + Network netip.Prefix +} + +func (a Addr) String() string { + return a.Network.String() +} func (n *NetLink) AddrList(link Link, family int) ( addresses []Addr, err error) { - return netlink.AddrList(link, family) + netlinkLink := linkToNetlinkLink(&link) + netlinkAddresses, err := netlink.AddrList(netlinkLink, family) + if err != nil { + return nil, err + } + + addresses = make([]Addr, len(netlinkAddresses)) + for i := range netlinkAddresses { + addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet) + } + + return addresses, nil } -func (n *NetLink) AddrReplace(link Link, addr *Addr) error { - return netlink.AddrReplace(link, addr) +func (n *NetLink) AddrReplace(link Link, addr Addr) error { + netlinkLink := linkToNetlinkLink(&link) + netlinkAddress := netlink.Addr{ + IPNet: netipPrefixToIPNet(addr.Network), + } + + return netlink.AddrReplace(netlinkLink, &netlinkAddress) } diff --git a/internal/netlink/conversion.go b/internal/netlink/conversion.go new file mode 100644 index 00000000..55910fc4 --- /dev/null +++ b/internal/netlink/conversion.go @@ -0,0 +1,62 @@ +package netlink + +import ( + "fmt" + "net" + "net/netip" +) + +func netipPrefixToIPNet(prefix netip.Prefix) (ipNet *net.IPNet) { + if !prefix.IsValid() { + return nil + } + + prefixAddr := prefix.Addr().Unmap() + ipMask := net.CIDRMask(prefix.Bits(), prefixAddr.BitLen()) + ip := netipAddrToNetIP(prefixAddr) + + return &net.IPNet{ + IP: ip, + Mask: ipMask, + } +} + +func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) { + if ipNet == nil || (len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len) { + return prefix + } + + var ip netip.Addr + if ipv4 := ipNet.IP.To4(); ipv4 != nil { + ip = netip.AddrFrom4([4]byte(ipv4)) + } else { + ip = netip.AddrFrom16([16]byte(ipNet.IP)) + } + bits, _ := ipNet.Mask.Size() + return netip.PrefixFrom(ip, bits) +} + +func netipAddrToNetIP(address netip.Addr) (ip net.IP) { + switch { + case !address.IsValid(): + return nil + case address.Is4() || address.Is4In6(): + bytes := address.As4() + return net.IP(bytes[:]) + default: + bytes := address.As16() + return net.IP(bytes[:]) + } +} + +func netIPToNetipAddress(ip net.IP) (address netip.Addr) { + if len(ip) != net.IPv4len && len(ip) != net.IPv6len { + return address // invalid + } + + address, ok := netip.AddrFromSlice(ip) + if !ok { + panic(fmt.Sprintf("converting %#v to netip.Addr failed", ip)) + } + return address.Unmap() +} diff --git a/internal/netlink/conversion_test.go b/internal/netlink/conversion_test.go new file mode 100644 index 00000000..c9e89594 --- /dev/null +++ b/internal/netlink/conversion_test.go @@ -0,0 +1,146 @@ +package netlink + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_netipPrefixToIPNet(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + prefix netip.Prefix + ipNet *net.IPNet + }{ + "empty_prefix": {}, + "IPv4_prefix": { + prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + ipNet: &net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + "IPv4-in-IPv6_prefix": { + prefix: netip.PrefixFrom(netip.AddrFrom16( + [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4}), + 24), + ipNet: &net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + "IPv6_prefix": { + prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8), + ipNet: &net.IPNet{ + IP: net.IPv6loopback, + Mask: net.IPMask{0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ipNet := netipPrefixToIPNet(testCase.prefix) + + assert.Equal(t, testCase.ipNet, ipNet) + }) + } +} + +func Test_netIPNetToNetipPrefix(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ipNet *net.IPNet + prefix netip.Prefix + }{ + "empty ipnet": {}, + "custom sized IP in ipnet": { + ipNet: &net.IPNet{ + IP: net.IP{1}, + }, + }, + "IPv4 ipnet": { + ipNet: &net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + }, + "IPv4-in-IPv6 ipnet": { + ipNet: &net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + }, + prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), + }, + "IPv6 ipnet": { + ipNet: &net.IPNet{ + IP: net.IPv6loopback, + Mask: net.IPMask{0xff}, + }, + prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + prefix := netIPNetToNetipPrefix(testCase.ipNet) + + assert.Equal(t, testCase.prefix, prefix) + }) + } +} + +func Test_netIPToNetipAddress(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ip net.IP + address netip.Addr + panicMessage string + }{ + "nil_ip": {}, + "ip_not_ipv4_or_ipv6": { + ip: net.IP{1}, + }, + "IPv4": { + ip: net.IPv4(1, 2, 3, 4), + address: netip.AddrFrom4([4]byte{1, 2, 3, 4}), + }, + "IPv6": { + ip: net.IPv6zero, + address: netip.AddrFrom16([16]byte{}), + }, + "IPv4 prefixed with 0xffff": { + ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4}, + address: netip.AddrFrom4([4]byte{1, 2, 3, 4}), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + netIPToNetipAddress(testCase.ip) + }) + return + } + + address := netIPToNetipAddress(testCase.ip) + assert.Equal(t, testCase.address, address) + }) + } +} diff --git a/internal/netlink/family.go b/internal/netlink/family.go index b70799d5..8c6f6587 100644 --- a/internal/netlink/family.go +++ b/internal/netlink/family.go @@ -6,20 +6,19 @@ import ( "github.com/vishvananda/netlink" ) -//nolint:revive const ( - FAMILY_ALL = netlink.FAMILY_ALL - FAMILY_V4 = netlink.FAMILY_V4 - FAMILY_V6 = netlink.FAMILY_V6 + FamilyAll = 0 + FamilyV4 = 2 + FamilyV6 = 10 ) func FamilyToString(family int) string { switch family { - case FAMILY_ALL: - return "all" - case FAMILY_V4: + case FamilyAll: + return "all" //nolint:goconst + case FamilyV4: return "v4" - case FAMILY_V6: + case FamilyV6: return "v6" default: return fmt.Sprint(family) diff --git a/internal/netlink/ipv6.go b/internal/netlink/ipv6.go index 0e75aff8..6a476fd5 100644 --- a/internal/netlink/ipv6.go +++ b/internal/netlink/ipv6.go @@ -14,20 +14,21 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) { var totalRoutes uint for _, link := range links { - routes, err := n.RouteList(link, netlink.FAMILY_V6) + link := link + routes, err := n.RouteList(&link, netlink.FAMILY_V6) if err != nil { return false, fmt.Errorf("listing IPv6 routes for link %s: %w", - link.Attrs().Name, err) + link.Name, err) } // Check each route for IPv6 due to Podman bug listing IPv4 routes // as IPv6 routes at container start, see: // https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949 for _, route := range routes { - sourceIsIPv6 := route.Src != nil && route.Src.To4() == nil - destinationIsIPv6 := route.Dst != nil && route.Dst.IP.To4() == nil + sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6() + destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6() if sourceIsIPv6 || destinationIsIPv6 { - n.debugLogger.Debugf("IPv6 is supported by link %s", link.Attrs().Name) + n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name) return true, nil } totalRoutes++ diff --git a/internal/netlink/link.go b/internal/netlink/link.go index 13ff28f3..e28bc69f 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -2,36 +2,117 @@ package netlink import "github.com/vishvananda/netlink" -type ( - Link = netlink.Link - Bridge = netlink.Bridge - Wireguard = netlink.Wireguard -) +type Link struct { + Type string + Name string + Index int + EncapType string + MTU uint16 + + NetNsID int + TxQLen int +} func (n *NetLink) LinkList() (links []Link, err error) { - return netlink.LinkList() + netlinkLinks, err := netlink.LinkList() + if err != nil { + return nil, err + } + + links = make([]Link, len(netlinkLinks)) + for i := range netlinkLinks { + links[i] = netlinkLinkToLink(netlinkLinks[i]) + } + + return links, nil } func (n *NetLink) LinkByName(name string) (link Link, err error) { - return netlink.LinkByName(name) + netlinkLink, err := netlink.LinkByName(name) + if err != nil { + return Link{}, err + } + + return netlinkLinkToLink(netlinkLink), nil } func (n *NetLink) LinkByIndex(index int) (link Link, err error) { - return netlink.LinkByIndex(index) + netlinkLink, err := netlink.LinkByIndex(index) + if err != nil { + return Link{}, err + } + + return netlinkLinkToLink(netlinkLink), nil } -func (n *NetLink) LinkAdd(link Link) (err error) { - return netlink.LinkAdd(link) +func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) { + netlinkLink := linkToNetlinkLink(&link) + err = netlink.LinkAdd(netlinkLink) + if err != nil { + return 0, err + } + return netlinkLink.Attrs().Index, nil } func (n *NetLink) LinkDel(link Link) (err error) { - return netlink.LinkDel(link) + return netlink.LinkDel(linkToNetlinkLink(&link)) } -func (n *NetLink) LinkSetUp(link Link) (err error) { - return netlink.LinkSetUp(link) +func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) { + netlinkLink := linkToNetlinkLink(&link) + err = netlink.LinkSetUp(netlinkLink) + if err != nil { + return 0, err + } + return netlinkLink.Attrs().Index, nil } func (n *NetLink) LinkSetDown(link Link) (err error) { - return netlink.LinkSetDown(link) + return netlink.LinkSetDown(linkToNetlinkLink(&link)) +} + +type netlinkLinkImpl struct { + attrs *netlink.LinkAttrs + linkType string +} + +func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs { + return n.attrs +} + +func (n *netlinkLinkImpl) Type() string { + return n.linkType +} + +func netlinkLinkToLink(netlinkLink netlink.Link) Link { + attributes := netlinkLink.Attrs() + return Link{ + Type: netlinkLink.Type(), + Name: attributes.Name, + Index: attributes.Index, + EncapType: attributes.EncapType, + MTU: uint16(attributes.MTU), + NetNsID: attributes.NetNsID, + TxQLen: attributes.TxQLen, + } +} + +// Warning: we must return `netlink.Link` and not `netlinkLinkImpl` +// so that the vishvananda/netlink package can compare the returned +// value against an untyped nil. +func linkToNetlinkLink(link *Link) netlink.Link { + if link == nil { + return nil + } + return &netlinkLinkImpl{ + linkType: link.Type, + attrs: &netlink.LinkAttrs{ // TODO get all original attributes + Name: link.Name, + Index: link.Index, + EncapType: link.EncapType, + MTU: int(link.MTU), + NetNsID: link.NetNsID, + TxQLen: link.TxQLen, + }, + } } diff --git a/internal/netlink/linkattrs.go b/internal/netlink/linkattrs.go deleted file mode 100644 index fd17f19b..00000000 --- a/internal/netlink/linkattrs.go +++ /dev/null @@ -1,9 +0,0 @@ -package netlink - -import "github.com/vishvananda/netlink" - -type LinkAttrs = netlink.LinkAttrs - -func NewLinkAttrs() LinkAttrs { - return netlink.NewLinkAttrs() -} diff --git a/internal/netlink/route.go b/internal/netlink/route.go index 171a6d1a..d29c54e8 100644 --- a/internal/netlink/route.go +++ b/internal/netlink/route.go @@ -1,22 +1,74 @@ package netlink -import "github.com/vishvananda/netlink" +import ( + "net/netip" -type Route = netlink.Route + "github.com/vishvananda/netlink" +) -func (n *NetLink) RouteList(link Link, family int) ( +type Route struct { + LinkIndex int + Dst netip.Prefix + Src netip.Addr + Gw netip.Addr + Priority int + Family int + Table int + Type int +} + +func (n *NetLink) RouteList(link *Link, family int) ( routes []Route, err error) { - return netlink.RouteList(link, family) + netlinkLink := linkToNetlinkLink(link) + netlinkRoutes, err := netlink.RouteList(netlinkLink, family) + if err != nil { + return nil, err + } + + routes = make([]Route, len(netlinkRoutes)) + for i := range netlinkRoutes { + routes[i] = netlinkRouteToRoute(netlinkRoutes[i]) + } + return routes, nil } -func (n *NetLink) RouteAdd(route *Route) error { - return netlink.RouteAdd(route) +func (n *NetLink) RouteAdd(route Route) error { + netlinkRoute := routeToNetlinkRoute(route) + return netlink.RouteAdd(&netlinkRoute) } -func (n *NetLink) RouteDel(route *Route) error { - return netlink.RouteDel(route) +func (n *NetLink) RouteDel(route Route) error { + netlinkRoute := routeToNetlinkRoute(route) + return netlink.RouteDel(&netlinkRoute) } -func (n *NetLink) RouteReplace(route *Route) error { - return netlink.RouteReplace(route) +func (n *NetLink) RouteReplace(route Route) error { + netlinkRoute := routeToNetlinkRoute(route) + return netlink.RouteReplace(&netlinkRoute) +} + +func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) { + return Route{ + LinkIndex: netlinkRoute.LinkIndex, + Dst: netIPNetToNetipPrefix(netlinkRoute.Dst), + Src: netIPToNetipAddress(netlinkRoute.Src), + Gw: netIPToNetipAddress(netlinkRoute.Gw), + Priority: netlinkRoute.Priority, + Family: netlinkRoute.Family, + Table: netlinkRoute.Table, + Type: netlinkRoute.Type, + } +} + +func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) { + return netlink.Route{ + LinkIndex: route.LinkIndex, + Dst: netipPrefixToIPNet(route.Dst), + Src: netipAddrToNetIP(route.Src), + Gw: netipAddrToNetIP(route.Gw), + Priority: route.Priority, + Family: route.Family, + Table: route.Table, + Type: route.Type, + } } diff --git a/internal/netlink/rule.go b/internal/netlink/rule.go index 8fcdd38b..92656b46 100644 --- a/internal/netlink/rule.go +++ b/internal/netlink/rule.go @@ -1,21 +1,90 @@ package netlink -import "github.com/vishvananda/netlink" +import ( + "fmt" + "net/netip" -type Rule = netlink.Rule + "github.com/vishvananda/netlink" +) -func NewRule() *Rule { - return netlink.NewRule() +type Rule struct { + Priority int + Family int + Table int + Mark int + Src netip.Prefix + Dst netip.Prefix + Invert bool +} + +func (r Rule) String() string { + from := "all" + if r.Src.IsValid() { + from = r.Src.String() + } + + to := "all" + if r.Dst.IsValid() { + to = r.Dst.String() + } + + return fmt.Sprintf("ip rule %d: from %s to %s table %d", + r.Priority, from, to, r.Table) +} + +func NewRule() Rule { + // defaults found from netlink.NewRule() for fields we use, + // the rest of the defaults is set when converting from a `Rule` + // to a `netlink.Rule` + return Rule{ + Priority: -1, + Mark: -1, + } } func (n *NetLink) RuleList(family int) (rules []Rule, err error) { - return netlink.RuleList(family) + netlinkRules, err := netlink.RuleList(family) + if err != nil { + return nil, err + } + + rules = make([]Rule, len(netlinkRules)) + for i := range netlinkRules { + rules[i] = netlinkRuleToRule(netlinkRules[i]) + } + return rules, nil } -func (n *NetLink) RuleAdd(rule *Rule) error { - return netlink.RuleAdd(rule) +func (n *NetLink) RuleAdd(rule Rule) error { + netlinkRule := ruleToNetlinkRule(rule) + return netlink.RuleAdd(&netlinkRule) } -func (n *NetLink) RuleDel(rule *Rule) error { - return netlink.RuleDel(rule) +func (n *NetLink) RuleDel(rule Rule) error { + netlinkRule := ruleToNetlinkRule(rule) + return netlink.RuleDel(&netlinkRule) +} + +func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) { + netlinkRule = *netlink.NewRule() + netlinkRule.Priority = rule.Priority + netlinkRule.Family = rule.Family + netlinkRule.Table = rule.Table + netlinkRule.Mark = rule.Mark + netlinkRule.Src = netipPrefixToIPNet(rule.Src) + netlinkRule.Dst = netipPrefixToIPNet(rule.Dst) + netlinkRule.Invert = rule.Invert + return netlinkRule +} + +func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) { + return Rule{ + Priority: netlinkRule.Priority, + Family: netlinkRule.Family, + Table: netlinkRule.Table, + Mark: netlinkRule.Mark, + Src: netIPNetToNetipPrefix(netlinkRule.Src), + Dst: netIPNetToNetipPrefix(netlinkRule.Dst), + Invert: netlinkRule.Invert, + } } diff --git a/internal/routing/conversion.go b/internal/routing/conversion.go index 1cb0fd4d..2d38a3df 100644 --- a/internal/routing/conversion.go +++ b/internal/routing/conversion.go @@ -6,34 +6,6 @@ import ( "net/netip" ) -func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) { - if prefix == nil { - return nil - } - - s := prefix.String() - ip, ipNet, err := net.ParseCIDR(s) - if err != nil { - panic(err) - } - ipNet.IP = ip - return ipNet -} - -func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) { - if len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len { - return prefix - } - var ip netip.Addr - if ipv4 := ipNet.IP.To4(); ipv4 != nil { - ip = netip.AddrFrom4([4]byte(ipv4)) - } else { - ip = netip.AddrFrom16([16]byte(ipNet.IP)) - } - bits, _ := ipNet.Mask.Size() - return netip.PrefixFrom(ip, bits) -} - func netIPToNetipAddress(ip net.IP) (address netip.Addr) { address, ok := netip.AddrFromSlice(ip) if !ok { diff --git a/internal/routing/conversion_test.go b/internal/routing/conversion_test.go index 53002956..a30dfbb5 100644 --- a/internal/routing/conversion_test.go +++ b/internal/routing/conversion_test.go @@ -8,54 +8,6 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_netIPNetToNetipPrefix(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - ipNet net.IPNet - prefix netip.Prefix - }{ - "empty ipnet": {}, - "custom sized IP in ipnet": { - ipNet: net.IPNet{ - IP: net.IP{1}, - }, - }, - "IPv4 ipnet": { - ipNet: net.IPNet{ - IP: net.IP{1, 2, 3, 4}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), - }, - "IPv4-in-IPv6 ipnet": { - ipNet: net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), - Mask: net.IPMask{255, 255, 255, 0}, - }, - prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), - }, - "IPv6 ipnet": { - ipNet: net.IPNet{ - IP: net.IPv6loopback, - Mask: net.IPMask{0xff}, - }, - prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8), - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - prefix := netIPNetToNetipPrefix(testCase.ipNet) - - assert.Equal(t, testCase.prefix, prefix) - }) - } -} - func Test_netIPToNetipAddress(t *testing.T) { t.Parallel() diff --git a/internal/routing/default.go b/internal/routing/default.go index 60e0746e..d550ec7d 100644 --- a/internal/routing/default.go +++ b/internal/routing/default.go @@ -25,17 +25,17 @@ func (d DefaultRoute) String() string { } func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll) if err != nil { return nil, fmt.Errorf("listing routes: %w", err) } for _, route := range routes { - if route.Dst != nil { + if route.Dst.IsValid() { continue } defaultRoute := DefaultRoute{ - Gateway: netIPToNetipAddress(route.Gw), + Gateway: route.Gw, Family: route.Family, } linkIndex := route.LinkIndex @@ -43,11 +43,10 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) { if err != nil { return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err) } - attributes := link.Attrs() - defaultRoute.NetInterface = attributes.Name - family := netlink.FAMILY_V6 - if route.Gw.To4() != nil { - family = netlink.FAMILY_V4 + defaultRoute.NetInterface = link.Name + family := netlink.FamilyV6 + if route.Gw.Is4() { + family = netlink.FamilyV4 } defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family) if err != nil { diff --git a/internal/routing/inbound.go b/internal/routing/inbound.go index e402d3c5..239f5666 100644 --- a/internal/routing/inbound.go +++ b/internal/routing/inbound.go @@ -23,7 +23,7 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err for _, defaultRoute := range defaultRoutes { defaultDestination := defaultDestinationIPv4 - if defaultRoute.Family == netlink.FAMILY_V6 { + if defaultRoute.Family == netlink.FamilyV6 { defaultDestination = defaultDestinationIPv6 } @@ -43,7 +43,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e for _, defaultRoute := range defaultRoutes { defaultDestination := defaultDestinationIPv4 - if defaultRoute.Family == netlink.FAMILY_V6 { + if defaultRoute.Family == netlink.FamilyV6 { defaultDestination = defaultDestinationIPv6 } @@ -68,8 +68,8 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo bits = 128 } defaultIPMasked := netip.PrefixFrom(assignedIP, bits) - ruleDstNet := (*netip.Prefix)(nil) - err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority) + ruleDstNet := netip.Prefix{} + err = r.addIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority) if err != nil { return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err) } @@ -86,8 +86,8 @@ func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRo bits = 128 } defaultIPMasked := netip.PrefixFrom(assignedIP, bits) - ruleDstNet := (*netip.Prefix)(nil) - err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority) + ruleDstNet := netip.Prefix{} + err = r.deleteIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority) if err != nil { return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err) } diff --git a/internal/routing/ip.go b/internal/routing/ip.go index 4fc96376..9f9d7429 100644 --- a/internal/routing/ip.go +++ b/internal/routing/ip.go @@ -19,8 +19,8 @@ var ( ) func ipMatchesFamily(ip netip.Addr, family int) bool { - return (family == netlink.FAMILY_V4 && ip.Is4()) || - (family == netlink.FAMILY_V6 && ip.Is6()) + return (family == netlink.FamilyV4 && ip.Is4()) || + (family == netlink.FamilyV6 && ip.Is6()) } func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) { diff --git a/internal/routing/local.go b/internal/routing/local.go index 59c7e2c9..075ca2a7 100644 --- a/internal/routing/local.go +++ b/internal/routing/local.go @@ -29,25 +29,25 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { localLinks := make(map[int]struct{}) for _, link := range links { - if link.Attrs().EncapType != "ether" { + if link.EncapType != "ether" { continue } - localLinks[link.Attrs().Index] = struct{}{} - r.logger.Info("local ethernet link found: " + link.Attrs().Name) + localLinks[link.Index] = struct{}{} + r.logger.Info("local ethernet link found: " + link.Name) } if len(localLinks) == 0 { return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links)) } - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll) if err != nil { return localNetworks, fmt.Errorf("listing routes: %w", err) } for _, route := range routes { - if route.Gw != nil || route.Dst == nil { + if route.Gw.IsValid() || !route.Dst.IsValid() { continue } else if _, ok := localLinks[route.LinkIndex]; !ok { continue @@ -55,7 +55,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { var localNet LocalNetwork - localNet.IPNet = netIPNetToNetipPrefix(*route.Dst) + localNet.IPNet = route.Dst r.logger.Info("local ipnet found: " + localNet.IPNet.String()) link, err := r.netLinker.LinkByIndex(route.LinkIndex) @@ -63,11 +63,11 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err) } - localNet.InterfaceName = link.Attrs().Name + localNet.InterfaceName = link.Name - family := netlink.FAMILY_V6 + family := netlink.FamilyV6 if localNet.IPNet.Addr().Is4() { - family = netlink.FAMILY_V4 + family = netlink.FamilyV4 } ip, err := r.assignedIP(localNet.InterfaceName, family) if err != nil { @@ -96,7 +96,8 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) { const localPriority = 98 // Main table was setup correctly by Docker, just need to add rules to use it - err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority) + src := netip.Prefix{} + err = r.addIPRule(src, subnet.IPNet, mainTable, localPriority) if err != nil { return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err) } diff --git a/internal/routing/mocks_test.go b/internal/routing/mocks_test.go index 23e73cd4..90f41438 100644 --- a/internal/routing/mocks_test.go +++ b/internal/routing/mocks_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - netlink "github.com/vishvananda/netlink" + netlink "github.com/qdm12/gluetun/internal/netlink" ) // MockNetLinker is a mock of NetLinker interface. @@ -50,7 +50,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca } // AddrReplace mocks base method. -func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error { +func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret0, _ := ret[0].(error) @@ -79,11 +79,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call { } // LinkAdd mocks base method. -func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkAdd", arg0) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LinkAdd indicates an expected call of LinkAdd. @@ -166,11 +167,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call } // LinkSetUp mocks base method. -func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkSetUp", arg0) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LinkSetUp indicates an expected call of LinkSetUp. @@ -180,7 +182,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { } // RouteAdd mocks base method. -func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error { +func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteAdd", arg0) ret0, _ := ret[0].(error) @@ -194,7 +196,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call { } // RouteDel mocks base method. -func (m *MockNetLinker) RouteDel(arg0 *netlink.Route) error { +func (m *MockNetLinker) RouteDel(arg0 netlink.Route) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteDel", arg0) ret0, _ := ret[0].(error) @@ -208,7 +210,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call { } // RouteList mocks base method. -func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { +func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteList", arg0, arg1) ret0, _ := ret[0].([]netlink.Route) @@ -223,7 +225,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C } // RouteReplace mocks base method. -func (m *MockNetLinker) RouteReplace(arg0 *netlink.Route) error { +func (m *MockNetLinker) RouteReplace(arg0 netlink.Route) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteReplace", arg0) ret0, _ := ret[0].(error) @@ -237,7 +239,7 @@ func (mr *MockNetLinkerMockRecorder) RouteReplace(arg0 interface{}) *gomock.Call } // RuleAdd mocks base method. -func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error { +func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RuleAdd", arg0) ret0, _ := ret[0].(error) @@ -251,7 +253,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call { } // RuleDel mocks base method. -func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error { +func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RuleDel", arg0) ret0, _ := ret[0].(error) diff --git a/internal/routing/outbound.go b/internal/routing/outbound.go index 7a0531f3..f2c81ee3 100644 --- a/internal/routing/outbound.go +++ b/internal/routing/outbound.go @@ -56,8 +56,8 @@ func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix, } } - ruleSrcNet := (*netip.Prefix)(nil) - ruleDstNet := &subnets[i] + ruleSrcNet := netip.Prefix{} + ruleDstNet := subnets[i] err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) if err != nil { warnings = append(warnings, @@ -81,8 +81,8 @@ func (r *Routing) addOutboundSubnets(subnets []netip.Prefix, } } - ruleSrcNet := (*netip.Prefix)(nil) - ruleDstNet := &subnets[i] + ruleSrcNet := netip.Prefix{} + ruleDstNet := subnets[i] err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) if err != nil { return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err) diff --git a/internal/routing/routes.go b/internal/routing/routes.go index c9422ba3..264c6394 100644 --- a/internal/routing/routes.go +++ b/internal/routing/routes.go @@ -23,12 +23,12 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr, } route := netlink.Route{ - Dst: NetipPrefixToIPNet(&destination), - Gw: gateway.AsSlice(), - LinkIndex: link.Attrs().Index, + Dst: destination, + Gw: gateway, + LinkIndex: link.Index, Table: table, } - if err := r.netLinker.RouteReplace(&route); err != nil { + if err := r.netLinker.RouteReplace(route); err != nil { return fmt.Errorf("replacing route for subnet %s at interface %s: %w", destinationStr, iface, err) } @@ -51,12 +51,12 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr, } route := netlink.Route{ - Dst: NetipPrefixToIPNet(&destination), - Gw: gateway.AsSlice(), - LinkIndex: link.Attrs().Index, + Dst: destination, + Gw: gateway, + LinkIndex: link.Index, Table: table, } - if err := r.netLinker.RouteDel(&route); err != nil { + if err := r.netLinker.RouteDel(route); err != nil { return fmt.Errorf("deleting route: for subnet %s at interface %s: %w", destinationStr, iface, err) } diff --git a/internal/routing/routing.go b/internal/routing/routing.go index 14020e4b..03eb313d 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -18,30 +18,30 @@ type NetLinker interface { type Addresser interface { AddrList(link netlink.Link, family int) ( addresses []netlink.Addr, err error) - AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrReplace(link netlink.Link, addr netlink.Addr) error } type Router interface { - RouteList(link netlink.Link, family int) ( + RouteList(link *netlink.Link, family int) ( routes []netlink.Route, err error) - RouteAdd(route *netlink.Route) error - RouteDel(route *netlink.Route) error - RouteReplace(route *netlink.Route) error + RouteAdd(route netlink.Route) error + RouteDel(route netlink.Route) error + RouteReplace(route netlink.Route) error } type Ruler interface { RuleList(family int) (rules []netlink.Rule, err error) - RuleAdd(rule *netlink.Rule) error - RuleDel(rule *netlink.Rule) error + RuleAdd(rule netlink.Rule) error + RuleDel(rule netlink.Rule) error } type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) LinkByIndex(index int) (link netlink.Link, err error) - LinkAdd(link netlink.Link) (err error) + LinkAdd(link netlink.Link) (linkIndex int, err error) LinkDel(link netlink.Link) (err error) - LinkSetUp(link netlink.Link) (err error) + LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) } diff --git a/internal/routing/rules.go b/internal/routing/rules.go index 772c58b0..110effe6 100644 --- a/internal/routing/rules.go +++ b/internal/routing/rules.go @@ -1,30 +1,28 @@ package routing import ( - "bytes" "fmt" - "net" "net/netip" "github.com/qdm12/gluetun/internal/netlink" ) -func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error { +func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error { const add = true r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) rule := netlink.NewRule() - rule.Src = NetipPrefixToIPNet(src) - rule.Dst = NetipPrefixToIPNet(dst) + rule.Src = src + rule.Dst = dst rule.Priority = priority rule.Table = table - existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) + existingRules, err := r.netLinker.RuleList(netlink.FamilyAll) if err != nil { return fmt.Errorf("listing rules: %w", err) } for i := range existingRules { - if !rulesAreEqual(&existingRules[i], rule) { + if !rulesAreEqual(existingRules[i], rule) { continue } return nil // already exists @@ -36,22 +34,22 @@ func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error { return nil } -func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error { +func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error { const add = false r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) rule := netlink.NewRule() - rule.Src = NetipPrefixToIPNet(src) - rule.Dst = NetipPrefixToIPNet(dst) + rule.Src = src + rule.Dst = dst rule.Priority = priority rule.Table = table - existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) + existingRules, err := r.netLinker.RuleList(netlink.FamilyAll) if err != nil { return fmt.Errorf("listing rules: %w", err) } for i := range existingRules { - if !rulesAreEqual(&existingRules[i], rule) { + if !rulesAreEqual(existingRules[i], rule) { continue } if err := r.netLinker.RuleDel(rule); err != nil { @@ -61,7 +59,7 @@ func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) erro return nil } -func ruleDbgMsg(add bool, src, dst *netip.Prefix, +func ruleDbgMsg(add bool, src, dst netip.Prefix, table, priority int) (debugMessage string) { debugMessage = "ip rule" @@ -71,11 +69,11 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix, debugMessage += " del" } - if src != nil { + if src.IsValid() { debugMessage += " from " + src.String() } - if dst != nil { + if dst.IsValid() { debugMessage += " to " + dst.String() } @@ -90,25 +88,20 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix, return debugMessage } -func rulesAreEqual(a, b *netlink.Rule) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - return ipNetsAreEqual(a.Src, b.Src) && - ipNetsAreEqual(a.Dst, b.Dst) && +func rulesAreEqual(a, b netlink.Rule) bool { + return ipPrefixesAreEqual(a.Src, b.Src) && + ipPrefixesAreEqual(a.Dst, b.Dst) && a.Priority == b.Priority && a.Table == b.Table } -func ipNetsAreEqual(a, b *net.IPNet) bool { - if a == nil && b == nil { +func ipPrefixesAreEqual(a, b netip.Prefix) bool { + if !a.IsValid() && !b.IsValid() { return true } - if a == nil || b == nil { + if !a.IsValid() || !b.IsValid() { return false } - return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask) + return a.Bits() == b.Bits() && + a.Addr().Compare(b.Addr()) == 0 } diff --git a/internal/routing/rules_test.go b/internal/routing/rules_test.go index 2d01b645..12f63666 100644 --- a/internal/routing/rules_test.go +++ b/internal/routing/rules_test.go @@ -2,7 +2,6 @@ package routing import ( "errors" - "net" "net/netip" "testing" @@ -12,17 +11,16 @@ import ( "github.com/stretchr/testify/require" ) -func makeNetipPrefix(n byte) *netip.Prefix { +func makeNetipPrefix(n byte) netip.Prefix { const bits = 24 - prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits) - return &prefix + return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits) } -func makeIPRule(src, dst *netip.Prefix, - table, priority int) *netlink.Rule { +func makeIPRule(src, dst netip.Prefix, + table, priority int) netlink.Rule { rule := netlink.NewRule() - rule.Src = NetipPrefixToIPNet(src) - rule.Dst = NetipPrefixToIPNet(dst) + rule.Src = src + rule.Dst = dst rule.Table = table rule.Priority = priority return rule @@ -40,13 +38,13 @@ func Test_Routing_addIPRule(t *testing.T) { type ruleAddCall struct { expected bool - ruleToAdd *netlink.Rule + ruleToAdd netlink.Rule err error } testCases := map[string]struct { - src *netip.Prefix - dst *netip.Prefix + src netip.Prefix + dst netip.Prefix table int priority int dbgMsg string @@ -69,8 +67,8 @@ func Test_Routing_addIPRule(t *testing.T) { dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), - *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, }, @@ -95,8 +93,8 @@ func Test_Routing_addIPRule(t *testing.T) { dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), - *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), + makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), }, }, ruleAdd: ruleAddCall{ @@ -116,7 +114,7 @@ func Test_Routing_addIPRule(t *testing.T) { logger.EXPECT().Debug(testCase.dbgMsg) netLinker := NewMockNetLinker(ctrl) - netLinker.EXPECT().RuleList(netlink.FAMILY_ALL). + netLinker.EXPECT().RuleList(netlink.FamilyAll). Return(testCase.ruleList.rules, testCase.ruleList.err) if testCase.ruleAdd.expected { netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd). @@ -153,13 +151,13 @@ func Test_Routing_deleteIPRule(t *testing.T) { type ruleDelCall struct { expected bool - ruleToDel *netlink.Rule + ruleToDel netlink.Rule err error } testCases := map[string]struct { - src *netip.Prefix - dst *netip.Prefix + src netip.Prefix + dst netip.Prefix table int priority int dbgMsg string @@ -182,7 +180,7 @@ func Test_Routing_deleteIPRule(t *testing.T) { dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, ruleDel: ruleDelCall{ @@ -200,8 +198,8 @@ func Test_Routing_deleteIPRule(t *testing.T) { dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), - *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), }, }, ruleDel: ruleDelCall{ @@ -217,8 +215,8 @@ func Test_Routing_deleteIPRule(t *testing.T) { dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", ruleList: ruleListCall{ rules: []netlink.Rule{ - *makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), - *makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), + makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), + makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), }, }, }, @@ -234,7 +232,7 @@ func Test_Routing_deleteIPRule(t *testing.T) { logger.EXPECT().Debug(testCase.dbgMsg) netLinker := NewMockNetLinker(ctrl) - netLinker.EXPECT().RuleList(netlink.FAMILY_ALL). + netLinker.EXPECT().RuleList(netlink.FamilyAll). Return(testCase.ruleList.rules, testCase.ruleList.err) if testCase.ruleDel.expected { netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel). @@ -264,8 +262,8 @@ func Test_ruleDbgMsg(t *testing.T) { testCases := map[string]struct { add bool - src *netip.Prefix - dst *netip.Prefix + src netip.Prefix + dst netip.Prefix table int priority int dbgMsg string @@ -307,38 +305,79 @@ func Test_rulesAreEqual(t *testing.T) { t.Parallel() testCases := map[string]struct { - a *netlink.Rule - b *netlink.Rule + a netlink.Rule + b netlink.Rule equal bool }{ - "both nil": { + "both_empty": { equal: true, }, - "first nil": { - b: &netlink.Rule{}, - }, - "second nil": { - a: &netlink.Rule{}, - }, - "both not nil": { - a: &netlink.Rule{}, - b: &netlink.Rule{}, - equal: true, - }, - "both equal": { - a: &netlink.Rule{ - Src: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + "not_equal_by_src": { + a: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Priority: 100, Table: 101, }, - b: &netlink.Rule{ - Src: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + b: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 100, + Table: 101, + }, + }, + "not_equal_by_dst": { + a: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32), + Priority: 100, + Table: 101, + }, + b: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 100, + Table: 101, + }, + }, + "not_equal_by_priority": { + a: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 999, + Table: 101, + }, + b: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 100, + Table: 101, + }, + }, + "not_equal_by_table": { + a: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 100, + Table: 999, + }, + b: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 100, + Table: 101, + }, + }, + "equal": { + a: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + Priority: 100, + Table: 101, + }, + b: netlink.Rule{ + Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), Priority: 100, Table: 101, }, @@ -358,58 +397,39 @@ func Test_rulesAreEqual(t *testing.T) { } } -func Test_ipNetsAreEqual(t *testing.T) { +func Test_ipPrefixesAreEqual(t *testing.T) { t.Parallel() testCases := map[string]struct { - a *net.IPNet - b *net.IPNet + a netip.Prefix + b netip.Prefix equal bool }{ - "both nil": { + "both_not_valid": { equal: true, }, - "first nil": { - b: &net.IPNet{}, + "first_not_valid": { + b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), }, - "second nil": { - a: &net.IPNet{}, + "second_not_valid": { + a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), }, - "both not nil": { - a: &net.IPNet{}, - b: &net.IPNet{}, + "both_equal": { + a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), equal: true, }, - "both equal": { - a: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - b: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - equal: true, + "both_not_equal_by_IP": { + a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 24), }, - "both not equal by IP": { - a: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - b: &net.IPNet{ - IP: net.IPv4(2, 2, 2, 2), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + "both_not_equal_by_bits": { + a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), }, - "both not equal by mask": { - a: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 255, 255), - }, - b: &net.IPNet{ - IP: net.IPv4(1, 1, 1, 1), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, + "both_not_equal_by_IP_and_bits": { + a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24), + b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), }, } @@ -418,7 +438,7 @@ func Test_ipNetsAreEqual(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - equal := ipNetsAreEqual(testCase.a, testCase.b) + equal := ipPrefixesAreEqual(testCase.a, testCase.b) assert.Equal(t, testCase.equal, equal) }) diff --git a/internal/routing/vpn.go b/internal/routing/vpn.go index 4615ed98..86fdff5f 100644 --- a/internal/routing/vpn.go +++ b/internal/routing/vpn.go @@ -1,10 +1,8 @@ package routing import ( - "bytes" "errors" "fmt" - "net" "net/netip" "github.com/qdm12/gluetun/internal/netlink" @@ -16,14 +14,14 @@ var ( ) func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll) if err != nil { return ip, fmt.Errorf("listing routes: %w", err) } defaultLinkIndex := -1 for _, route := range routes { - if route.Dst == nil { + if !route.Dst.IsValid() { defaultLinkIndex = route.LinkIndex break } @@ -34,17 +32,17 @@ func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) { for _, route := range routes { if route.LinkIndex == defaultLinkIndex && - route.Dst != nil && - !IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) && - bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) { - return netIPToNetipAddress(route.Dst.IP), nil + route.Dst.IsValid() && + !IPIsPrivate(route.Dst.Addr()) && + route.Dst.IsSingleIP() { + return route.Dst.Addr(), nil } } return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes)) } func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll) if err != nil { return ip, fmt.Errorf("listing routes: %w", err) } @@ -53,11 +51,11 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) { if err != nil { return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err) } - interfaceName := link.Attrs().Name + interfaceName := link.Name if interfaceName == vpnIntf && - route.Dst != nil && - route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { - return netIPToNetipAddress(route.Gw), nil + route.Dst.IsValid() && + route.Dst.Addr().IsUnspecified() { + return route.Gw, nil } } return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes)) diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index a77748f5..d52e2297 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -42,7 +42,7 @@ type Storage interface { } type NetLinker interface { - AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrReplace(link netlink.Link, addr netlink.Addr) error Router Ruler Linker @@ -50,22 +50,22 @@ type NetLinker interface { } type Router interface { - RouteList(link netlink.Link, family int) ( + RouteList(link *netlink.Link, family int) ( routes []netlink.Route, err error) - RouteAdd(route *netlink.Route) error + RouteAdd(route netlink.Route) error } type Ruler interface { - RuleAdd(rule *netlink.Rule) error - RuleDel(rule *netlink.Rule) error + RuleAdd(rule netlink.Rule) error + RuleDel(rule netlink.Rule) error } type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) - LinkAdd(link netlink.Link) (err error) + LinkAdd(link netlink.Link) (linkIndex int, err error) LinkDel(link netlink.Link) (err error) - LinkSetUp(link netlink.Link) (err error) + LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) } diff --git a/internal/wireguard/address.go b/internal/wireguard/address.go index e4aa748a..196fb728 100644 --- a/internal/wireguard/address.go +++ b/internal/wireguard/address.go @@ -5,7 +5,6 @@ import ( "net/netip" "github.com/qdm12/gluetun/internal/netlink" - "github.com/qdm12/gluetun/internal/routing" ) func (w *Wireguard) addAddresses(link netlink.Link, @@ -15,15 +14,14 @@ func (w *Wireguard) addAddresses(link netlink.Link, continue } - ipNet := ipNet - address := &netlink.Addr{ - IPNet: routing.NetipPrefixToIPNet(&ipNet), + address := netlink.Addr{ + Network: ipNet, } err = w.netlink.AddrReplace(link, address) if err != nil { return fmt.Errorf("%w: when adding address %s to link %s", - err, address, link.Attrs().Name) + err, address, link.Name) } } diff --git a/internal/wireguard/address_test.go b/internal/wireguard/address_test.go index 9964e461..2d4a453e 100644 --- a/internal/wireguard/address_test.go +++ b/internal/wireguard/address_test.go @@ -7,7 +7,6 @@ import ( "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/netlink" - "github.com/qdm12/gluetun/internal/routing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,14 +17,6 @@ func Test_Wireguard_addAddresses(t *testing.T) { ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32) ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64) - newLink := func() netlink.Link { - linkAttrs := netlink.NewLinkAttrs() - linkAttrs.Name = "a_bridge" - return &netlink.Bridge{ - LinkAttrs: linkAttrs, - } - } - errDummy := errors.New("dummy") testCases := map[string]struct { @@ -35,15 +26,15 @@ func Test_Wireguard_addAddresses(t *testing.T) { err error }{ "success": { - link: newLink(), + link: netlink.Link{Type: "wireguard"}, addrs: []netip.Prefix{ipNetOne, ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). - AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). + AddrReplace(link, netlink.Addr{Network: ipNetOne}). Return(nil) netLinker.EXPECT(). - AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}). + AddrReplace(link, netlink.Addr{Network: ipNetTwo}). Return(nil).After(firstCall) return &Wireguard{ netlink: netLinker, @@ -54,12 +45,12 @@ func Test_Wireguard_addAddresses(t *testing.T) { }, }, "first add error": { - link: newLink(), + link: netlink.Link{Type: "wireguard", Name: "a_bridge"}, addrs: []netip.Prefix{ipNetOne, ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { netLinker := NewMockNetLinker(ctrl) netLinker.EXPECT(). - AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). + AddrReplace(link, netlink.Addr{Network: ipNetOne}). Return(errDummy) return &Wireguard{ netlink: netLinker, @@ -71,15 +62,15 @@ func Test_Wireguard_addAddresses(t *testing.T) { err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"), }, "second add error": { - link: newLink(), + link: netlink.Link{Type: "wireguard", Name: "a_bridge"}, addrs: []netip.Prefix{ipNetOne, ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { netLinker := NewMockNetLinker(ctrl) firstCall := netLinker.EXPECT(). - AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). + AddrReplace(link, netlink.Addr{Network: ipNetOne}). Return(nil) netLinker.EXPECT(). - AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}). + AddrReplace(link, netlink.Addr{Network: ipNetTwo}). Return(errDummy).After(firstCall) return &Wireguard{ netlink: netLinker, @@ -91,7 +82,6 @@ func Test_Wireguard_addAddresses(t *testing.T) { err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"), }, "ignore IPv6": { - link: newLink(), addrs: []netip.Prefix{ipNetTwo}, wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { return &Wireguard{ diff --git a/internal/wireguard/config.go b/internal/wireguard/config.go index 0e10554a..f351ed59 100644 --- a/internal/wireguard/config.go +++ b/internal/wireguard/config.go @@ -3,6 +3,7 @@ package wireguard import ( "fmt" "net" + "net/netip" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -53,8 +54,14 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) { PublicKey: publicKey, PresharedKey: preSharedKey, AllowedIPs: []net.IPNet{ - *allIPv4(), - *allIPv6(), + { + IP: net.IPv4(0, 0, 0, 0), + Mask: []byte{0, 0, 0, 0}, + }, + { + IP: net.IPv6zero, + Mask: []byte(net.IPv6zero), + }, }, ReplaceAllowedIPs: true, Endpoint: &net.UDPAddr{ @@ -68,16 +75,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) { return config, nil } -func allIPv4() (ipNet *net.IPNet) { - return &net.IPNet{ - IP: net.IPv4(0, 0, 0, 0), - Mask: []byte{0, 0, 0, 0}, - } +func allIPv4() (prefix netip.Prefix) { + const bits = 0 + return netip.PrefixFrom(netip.IPv4Unspecified(), bits) } -func allIPv6() (ipNet *net.IPNet) { - return &net.IPNet{ - IP: net.IPv6zero, - Mask: []byte(net.IPv6zero), - } +func allIPv6() (prefix netip.Prefix) { + const bits = 0 + return netip.PrefixFrom(netip.IPv6Unspecified(), bits) } diff --git a/internal/wireguard/netlink_integration_test.go b/internal/wireguard/netlink_integration_test.go index f3ecdb42..92e18f63 100644 --- a/internal/wireguard/netlink_integration_test.go +++ b/internal/wireguard/netlink_integration_test.go @@ -24,10 +24,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { netlinker := netlink.New(&noopDebugLogger{}) - linkAttrs := netlink.NewLinkAttrs() - linkAttrs.Name = "test_8081" - link := &netlink.Bridge{ - LinkAttrs: linkAttrs, + link := netlink.Link{ + Type: "bridge", + Name: "test_8081", } // Remove any previously created test interface from a crashed/panic @@ -37,8 +36,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { require.NoError(t, err) } - err = netlinker.LinkAdd(link) + linkIndex, err := netlinker.LinkAdd(link) require.NoError(t, err) + link.Index = linkIndex defer func() { err = netlinker.LinkDel(link) @@ -63,14 +63,12 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { err = wg.addAddresses(link, addresses) require.NoError(t, err) - netlinkAddresses, err := netlinker.AddrList(link, netlink.FAMILY_ALL) + netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll) require.NoError(t, err) require.Equal(t, len(addresses), len(netlinkAddresses)) for i, netlinkAddress := range netlinkAddresses { - require.NotNil(t, netlinkAddress.IPNet) - ipNet, err := netip.ParsePrefix(netlinkAddress.IPNet.String()) - require.NoError(t, err) - assert.Equal(t, addresses[i], ipNet) + require.NotNil(t, netlinkAddress.Network) + assert.Equal(t, addresses[i], netlinkAddress.Network) } } } @@ -95,7 +93,7 @@ func Test_netlink_Wireguard_addRule(t *testing.T) { assert.NoError(t, err) }() - rules, err := netlinker.RuleList(netlink.FAMILY_V4) + rules, err := netlinker.RuleList(netlink.FamilyV4) require.NoError(t, err) var rule netlink.Rule var ruleFound bool @@ -107,15 +105,10 @@ func Test_netlink_Wireguard_addRule(t *testing.T) { } require.True(t, ruleFound) expectedRule := netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, - Table: firewallMark, - Mask: 4294967295, - Goto: -1, - Flow: -1, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, + Invert: true, + Priority: rulePriority, + Mark: firewallMark, + Table: firewallMark, } assert.Equal(t, expectedRule, rule) diff --git a/internal/wireguard/netlinker.go b/internal/wireguard/netlinker.go index 9c80d3e5..02070015 100644 --- a/internal/wireguard/netlinker.go +++ b/internal/wireguard/netlinker.go @@ -5,7 +5,7 @@ import "github.com/qdm12/gluetun/internal/netlink" //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker type NetLinker interface { - AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrReplace(link netlink.Link, addr netlink.Addr) error Router Ruler Linker @@ -13,21 +13,21 @@ type NetLinker interface { } type Router interface { - RouteList(link netlink.Link, family int) ( + RouteList(link *netlink.Link, family int) ( routes []netlink.Route, err error) - RouteAdd(route *netlink.Route) error + RouteAdd(route netlink.Route) error } type Ruler interface { - RuleAdd(rule *netlink.Rule) error - RuleDel(rule *netlink.Rule) error + RuleAdd(rule netlink.Rule) error + RuleDel(rule netlink.Rule) error } type Linker interface { - LinkAdd(link netlink.Link) (err error) + LinkAdd(link netlink.Link) (linkIndex int, err error) LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) - LinkSetUp(link netlink.Link) error + LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) error LinkDel(link netlink.Link) error } diff --git a/internal/wireguard/netlinker_mock_test.go b/internal/wireguard/netlinker_mock_test.go index c041fadb..4e8199a4 100644 --- a/internal/wireguard/netlinker_mock_test.go +++ b/internal/wireguard/netlinker_mock_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - netlink "github.com/vishvananda/netlink" + netlink "github.com/qdm12/gluetun/internal/netlink" ) // MockNetLinker is a mock of NetLinker interface. @@ -35,7 +35,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder { } // AddrReplace mocks base method. -func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error { +func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret0, _ := ret[0].(error) @@ -64,11 +64,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call { } // LinkAdd mocks base method. -func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkAdd", arg0) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LinkAdd indicates an expected call of LinkAdd. @@ -136,11 +137,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call } // LinkSetUp mocks base method. -func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { +func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LinkSetUp", arg0) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LinkSetUp indicates an expected call of LinkSetUp. @@ -150,7 +152,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { } // RouteAdd mocks base method. -func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error { +func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteAdd", arg0) ret0, _ := ret[0].(error) @@ -164,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call { } // RouteList mocks base method. -func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { +func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RouteList", arg0, arg1) ret0, _ := ret[0].([]netlink.Route) @@ -179,7 +181,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C } // RuleAdd mocks base method. -func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error { +func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RuleAdd", arg0) ret0, _ := ret[0].(error) @@ -193,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call { } // RuleDel mocks base method. -func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error { +func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RuleDel", arg0) ret0, _ := ret[0].(error) diff --git a/internal/wireguard/route.go b/internal/wireguard/route.go index 11c0be18..c133a937 100644 --- a/internal/wireguard/route.go +++ b/internal/wireguard/route.go @@ -2,17 +2,17 @@ package wireguard import ( "fmt" - "net" + "net/netip" "github.com/qdm12/gluetun/internal/netlink" ) // TODO add IPv6 route if IPv6 is supported -func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet, +func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, firewallMark int) (err error) { - route := &netlink.Route{ - LinkIndex: link.Attrs().Index, + route := netlink.Route{ + LinkIndex: link.Index, Dst: dst, Table: firewallMark, } @@ -21,7 +21,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet, if err != nil { return fmt.Errorf( "adding route for link %s, destination %s and table %d: %w", - link.Attrs().Name, dst, firewallMark, err) + link.Name, dst, firewallMark, err) } return err diff --git a/internal/wireguard/route_test.go b/internal/wireguard/route_test.go index 90638e88..b95a9cc4 100644 --- a/internal/wireguard/route_test.go +++ b/internal/wireguard/route_test.go @@ -2,7 +2,7 @@ package wireguard import ( "errors" - "net" + "net/netip" "testing" "github.com/golang/mock/gomock" @@ -15,41 +15,40 @@ func Test_Wireguard_addRoute(t *testing.T) { t.Parallel() const linkIndex = 88 - newLink := func() netlink.Link { - linkAttrs := netlink.NewLinkAttrs() - linkAttrs.Name = "a_bridge" - linkAttrs.Index = linkIndex - return &netlink.Bridge{ - LinkAttrs: linkAttrs, - } - } - ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)} + + ipPrefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32) + const firewallMark = 51820 errDummy := errors.New("dummy") testCases := map[string]struct { link netlink.Link - dst *net.IPNet - expectedRoute *netlink.Route + dst netip.Prefix + expectedRoute netlink.Route routeAddErr error err error }{ "success": { - link: newLink(), - dst: ipNet, - expectedRoute: &netlink.Route{ + link: netlink.Link{ + Index: linkIndex, + }, + dst: ipPrefix, + expectedRoute: netlink.Route{ LinkIndex: linkIndex, - Dst: ipNet, + Dst: ipPrefix, Table: firewallMark, }, }, "route add error": { - link: newLink(), - dst: ipNet, - expectedRoute: &netlink.Route{ + link: netlink.Link{ + Name: "a_bridge", + Index: linkIndex, + }, + dst: ipPrefix, + expectedRoute: netlink.Route{ LinkIndex: linkIndex, - Dst: ipNet, + Dst: ipPrefix, Table: firewallMark, }, routeAddErr: errDummy, diff --git a/internal/wireguard/rule_test.go b/internal/wireguard/rule_test.go index 30a86e3f..b883d39a 100644 --- a/internal/wireguard/rule_test.go +++ b/internal/wireguard/rule_test.go @@ -21,54 +21,39 @@ func Test_Wireguard_addRule(t *testing.T) { errDummy := errors.New("dummy") testCases := map[string]struct { - expectedRule *netlink.Rule + expectedRule netlink.Rule ruleAddErr error err error ruleDelErr error cleanupErr error }{ "success": { - expectedRule: &netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, - Table: firewallMark, - Mask: -1, - Goto: -1, - Flow: -1, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Family: family, + expectedRule: netlink.Rule{ + Invert: true, + Priority: rulePriority, + Mark: firewallMark, + Table: firewallMark, + Family: family, }, }, "rule add error": { - expectedRule: &netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, - Table: firewallMark, - Mask: -1, - Goto: -1, - Flow: -1, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Family: family, + expectedRule: netlink.Rule{ + Invert: true, + Priority: rulePriority, + Mark: firewallMark, + Table: firewallMark, + Family: family, }, ruleAddErr: errDummy, err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"), }, "rule delete error": { - expectedRule: &netlink.Rule{ - Invert: true, - Priority: rulePriority, - Mark: firewallMark, - Table: firewallMark, - Mask: -1, - Goto: -1, - Flow: -1, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Family: family, + expectedRule: netlink.Rule{ + Invert: true, + Priority: rulePriority, + Mark: firewallMark, + Table: firewallMark, + Family: family, }, ruleDelErr: errDummy, cleanupErr: errors.New("deleting rule ip rule 987: from all to all table 456: dummy"), diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index d4106771..e9095766 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -93,10 +93,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return } - if err := w.netlink.LinkSetUp(link); err != nil { + linkIndex, err := w.netlink.LinkSetUp(link) + if err != nil { waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) return } + link.Index = linkIndex closers.add("shutting down link", stepFour, func() error { return w.netlink.LinkSetDown(link) }) @@ -161,17 +163,16 @@ func setupKernelSpace(ctx context.Context, interfaceName string, netLinker NetLinker, mtu uint16, closers *closers, logger Logger) ( link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { - linkAttrs := netlink.LinkAttrs{ + link = netlink.Link{ + Type: "wireguard", Name: interfaceName, - MTU: int(mtu), + MTU: mtu, } - link = &netlink.Wireguard{ - LinkAttrs: linkAttrs, - } - err = netLinker.LinkAdd(link) + linkIndex, err := netLinker.LinkAdd(link) if err != nil { - return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err) + return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err) } + link.Index = linkIndex closers.add("deleting link", stepFive, func() error { return netLinker.LinkDel(link) }) @@ -191,22 +192,22 @@ func setupUserSpace(ctx context.Context, link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { tun, err := tun.CreateTUN(interfaceName, int(mtu)) if err != nil { - return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) + return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) } closers.add("closing TUN device", stepSeven, tun.Close) tunName, err := tun.Name() if err != nil { - return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) + return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) } else if tunName != interfaceName { - return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", + return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", ErrCreateTun, interfaceName, tunName) } link, err = netLinker.LinkByName(interfaceName) if err != nil { - return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) + return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) } closers.add("deleting link", stepFive, func() error { return netLinker.LinkDel(link) @@ -226,14 +227,14 @@ func setupUserSpace(ctx context.Context, uapiFile, err := ipc.UAPIOpen(interfaceName) if err != nil { - return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) + return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) } closers.add("closing UAPI file", stepThree, uapiFile.Close) uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile) if err != nil { - return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) + return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) } closers.add("closing UAPI listener", stepTwo, uapiListener.Close)