chore(netlink): define own types with minimal fields

- Allow to swap `github.com/vishvananda/netlink`
- Allow to add build tags for each platform
- One step closer to development on non-Linux platforms
This commit is contained in:
Quentin McGaw
2023-05-29 06:44:58 +00:00
parent 163ac48ce4
commit 38ddcfa756
34 changed files with 828 additions and 493 deletions

View File

@@ -6,34 +6,6 @@ import (
"net/netip"
)
func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) {
if prefix == nil {
return nil
}
s := prefix.String()
ip, ipNet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
ipNet.IP = ip
return ipNet
}
func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) {
if len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len {
return prefix
}
var ip netip.Addr
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
ip = netip.AddrFrom4([4]byte(ipv4))
} else {
ip = netip.AddrFrom16([16]byte(ipNet.IP))
}
bits, _ := ipNet.Mask.Size()
return netip.PrefixFrom(ip, bits)
}
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
address, ok := netip.AddrFromSlice(ip)
if !ok {

View File

@@ -8,54 +8,6 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_netIPNetToNetipPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ipNet net.IPNet
prefix netip.Prefix
}{
"empty ipnet": {},
"custom sized IP in ipnet": {
ipNet: net.IPNet{
IP: net.IP{1},
},
},
"IPv4 ipnet": {
ipNet: net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv4-in-IPv6 ipnet": {
ipNet: net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv6 ipnet": {
ipNet: net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff},
},
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix := netIPNetToNetipPrefix(testCase.ipNet)
assert.Equal(t, testCase.prefix, prefix)
})
}
}
func Test_netIPToNetipAddress(t *testing.T) {
t.Parallel()

View File

@@ -25,17 +25,17 @@ func (d DefaultRoute) String() string {
}
func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return nil, fmt.Errorf("listing routes: %w", err)
}
for _, route := range routes {
if route.Dst != nil {
if route.Dst.IsValid() {
continue
}
defaultRoute := DefaultRoute{
Gateway: netIPToNetipAddress(route.Gw),
Gateway: route.Gw,
Family: route.Family,
}
linkIndex := route.LinkIndex
@@ -43,11 +43,10 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
if err != nil {
return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err)
}
attributes := link.Attrs()
defaultRoute.NetInterface = attributes.Name
family := netlink.FAMILY_V6
if route.Gw.To4() != nil {
family = netlink.FAMILY_V4
defaultRoute.NetInterface = link.Name
family := netlink.FamilyV6
if route.Gw.Is4() {
family = netlink.FamilyV4
}
defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family)
if err != nil {

View File

@@ -23,7 +23,7 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err
for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4
if defaultRoute.Family == netlink.FAMILY_V6 {
if defaultRoute.Family == netlink.FamilyV6 {
defaultDestination = defaultDestinationIPv6
}
@@ -43,7 +43,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4
if defaultRoute.Family == netlink.FAMILY_V6 {
if defaultRoute.Family == netlink.FamilyV6 {
defaultDestination = defaultDestinationIPv6
}
@@ -68,8 +68,8 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
bits = 128
}
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil)
err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
ruleDstNet := netip.Prefix{}
err = r.addIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err)
}
@@ -86,8 +86,8 @@ func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
bits = 128
}
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil)
err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
ruleDstNet := netip.Prefix{}
err = r.deleteIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err)
}

View File

@@ -19,8 +19,8 @@ var (
)
func ipMatchesFamily(ip netip.Addr, family int) bool {
return (family == netlink.FAMILY_V4 && ip.Is4()) ||
(family == netlink.FAMILY_V6 && ip.Is6())
return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FamilyV6 && ip.Is6())
}
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {

View File

@@ -29,25 +29,25 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
localLinks := make(map[int]struct{})
for _, link := range links {
if link.Attrs().EncapType != "ether" {
if link.EncapType != "ether" {
continue
}
localLinks[link.Attrs().Index] = struct{}{}
r.logger.Info("local ethernet link found: " + link.Attrs().Name)
localLinks[link.Index] = struct{}{}
r.logger.Info("local ethernet link found: " + link.Name)
}
if len(localLinks) == 0 {
return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links))
}
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return localNetworks, fmt.Errorf("listing routes: %w", err)
}
for _, route := range routes {
if route.Gw != nil || route.Dst == nil {
if route.Gw.IsValid() || !route.Dst.IsValid() {
continue
} else if _, ok := localLinks[route.LinkIndex]; !ok {
continue
@@ -55,7 +55,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
var localNet LocalNetwork
localNet.IPNet = netIPNetToNetipPrefix(*route.Dst)
localNet.IPNet = route.Dst
r.logger.Info("local ipnet found: " + localNet.IPNet.String())
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
@@ -63,11 +63,11 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
}
localNet.InterfaceName = link.Attrs().Name
localNet.InterfaceName = link.Name
family := netlink.FAMILY_V6
family := netlink.FamilyV6
if localNet.IPNet.Addr().Is4() {
family = netlink.FAMILY_V4
family = netlink.FamilyV4
}
ip, err := r.assignedIP(localNet.InterfaceName, family)
if err != nil {
@@ -96,7 +96,8 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
const localPriority = 98
// Main table was setup correctly by Docker, just need to add rules to use it
err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority)
src := netip.Prefix{}
err = r.addIPRule(src, subnet.IPNet, mainTable, localPriority)
if err != nil {
return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err)
}

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink"
netlink "github.com/qdm12/gluetun/internal/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
@@ -50,7 +50,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -79,11 +79,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkAdd indicates an expected call of LinkAdd.
@@ -166,11 +167,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -180,7 +182,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
@@ -194,7 +196,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteDel mocks base method.
func (m *MockNetLinker) RouteDel(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteDel(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteDel", arg0)
ret0, _ := ret[0].(error)
@@ -208,7 +210,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route)
@@ -223,7 +225,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
}
// RouteReplace mocks base method.
func (m *MockNetLinker) RouteReplace(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteReplace(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteReplace", arg0)
ret0, _ := ret[0].(error)
@@ -237,7 +239,7 @@ func (mr *MockNetLinkerMockRecorder) RouteReplace(arg0 interface{}) *gomock.Call
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
@@ -251,7 +253,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)

View File

@@ -56,8 +56,8 @@ func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix,
}
}
ruleSrcNet := (*netip.Prefix)(nil)
ruleDstNet := &subnets[i]
ruleSrcNet := netip.Prefix{}
ruleDstNet := subnets[i]
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil {
warnings = append(warnings,
@@ -81,8 +81,8 @@ func (r *Routing) addOutboundSubnets(subnets []netip.Prefix,
}
}
ruleSrcNet := (*netip.Prefix)(nil)
ruleDstNet := &subnets[i]
ruleSrcNet := netip.Prefix{}
ruleDstNet := subnets[i]
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil {
return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err)

View File

@@ -23,12 +23,12 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
}
route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway.AsSlice(),
LinkIndex: link.Attrs().Index,
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Table: table,
}
if err := r.netLinker.RouteReplace(&route); err != nil {
if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
destinationStr, iface, err)
}
@@ -51,12 +51,12 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
}
route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway.AsSlice(),
LinkIndex: link.Attrs().Index,
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Table: table,
}
if err := r.netLinker.RouteDel(&route); err != nil {
if err := r.netLinker.RouteDel(route); err != nil {
return fmt.Errorf("deleting route: for subnet %s at interface %s: %w",
destinationStr, iface, err)
}

View File

@@ -18,30 +18,30 @@ type NetLinker interface {
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
}
type Router interface {
RouteList(link netlink.Link, family int) (
RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error
RouteDel(route *netlink.Route) error
RouteReplace(route *netlink.Route) error
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}

View File

@@ -1,30 +1,28 @@
package routing
import (
"bytes"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
const add = true
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
return fmt.Errorf("listing rules: %w", err)
}
for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) {
if !rulesAreEqual(existingRules[i], rule) {
continue
}
return nil // already exists
@@ -36,22 +34,22 @@ func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
return nil
}
func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error {
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
const add = false
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
return fmt.Errorf("listing rules: %w", err)
}
for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) {
if !rulesAreEqual(existingRules[i], rule) {
continue
}
if err := r.netLinker.RuleDel(rule); err != nil {
@@ -61,7 +59,7 @@ func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) erro
return nil
}
func ruleDbgMsg(add bool, src, dst *netip.Prefix,
func ruleDbgMsg(add bool, src, dst netip.Prefix,
table, priority int) (debugMessage string) {
debugMessage = "ip rule"
@@ -71,11 +69,11 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
debugMessage += " del"
}
if src != nil {
if src.IsValid() {
debugMessage += " from " + src.String()
}
if dst != nil {
if dst.IsValid() {
debugMessage += " to " + dst.String()
}
@@ -90,25 +88,20 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
return debugMessage
}
func rulesAreEqual(a, b *netlink.Rule) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return ipNetsAreEqual(a.Src, b.Src) &&
ipNetsAreEqual(a.Dst, b.Dst) &&
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
a.Priority == b.Priority &&
a.Table == b.Table
}
func ipNetsAreEqual(a, b *net.IPNet) bool {
if a == nil && b == nil {
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if !a.IsValid() && !b.IsValid() {
return true
}
if a == nil || b == nil {
if !a.IsValid() || !b.IsValid() {
return false
}
return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}

View File

@@ -2,7 +2,6 @@ package routing
import (
"errors"
"net"
"net/netip"
"testing"
@@ -12,17 +11,16 @@ import (
"github.com/stretchr/testify/require"
)
func makeNetipPrefix(n byte) *netip.Prefix {
func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24
prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
return &prefix
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
}
func makeIPRule(src, dst *netip.Prefix,
table, priority int) *netlink.Rule {
func makeIPRule(src, dst netip.Prefix,
table, priority int) netlink.Rule {
rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Src = src
rule.Dst = dst
rule.Table = table
rule.Priority = priority
return rule
@@ -40,13 +38,13 @@ func Test_Routing_addIPRule(t *testing.T) {
type ruleAddCall struct {
expected bool
ruleToAdd *netlink.Rule
ruleToAdd netlink.Rule
err error
}
testCases := map[string]struct {
src *netip.Prefix
dst *netip.Prefix
src netip.Prefix
dst netip.Prefix
table int
priority int
dbgMsg string
@@ -69,8 +67,8 @@ func Test_Routing_addIPRule(t *testing.T) {
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
},
@@ -95,8 +93,8 @@ func Test_Routing_addIPRule(t *testing.T) {
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
ruleAdd: ruleAddCall{
@@ -116,7 +114,7 @@ func Test_Routing_addIPRule(t *testing.T) {
logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleAdd.expected {
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
@@ -153,13 +151,13 @@ func Test_Routing_deleteIPRule(t *testing.T) {
type ruleDelCall struct {
expected bool
ruleToDel *netlink.Rule
ruleToDel netlink.Rule
err error
}
testCases := map[string]struct {
src *netip.Prefix
dst *netip.Prefix
src netip.Prefix
dst netip.Prefix
table int
priority int
dbgMsg string
@@ -182,7 +180,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
@@ -200,8 +198,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
@@ -217,8 +215,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
},
@@ -234,7 +232,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleDel.expected {
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
@@ -264,8 +262,8 @@ func Test_ruleDbgMsg(t *testing.T) {
testCases := map[string]struct {
add bool
src *netip.Prefix
dst *netip.Prefix
src netip.Prefix
dst netip.Prefix
table int
priority int
dbgMsg string
@@ -307,38 +305,79 @@ func Test_rulesAreEqual(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
a *netlink.Rule
b *netlink.Rule
a netlink.Rule
b netlink.Rule
equal bool
}{
"both nil": {
"both_empty": {
equal: true,
},
"first nil": {
b: &netlink.Rule{},
},
"second nil": {
a: &netlink.Rule{},
},
"both not nil": {
a: &netlink.Rule{},
b: &netlink.Rule{},
equal: true,
},
"both equal": {
a: &netlink.Rule{
Src: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
"not_equal_by_src": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
b: &netlink.Rule{
Src: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_dst": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_priority": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 999,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_table": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 999,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"equal": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
@@ -358,58 +397,39 @@ func Test_rulesAreEqual(t *testing.T) {
}
}
func Test_ipNetsAreEqual(t *testing.T) {
func Test_ipPrefixesAreEqual(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
a *net.IPNet
b *net.IPNet
a netip.Prefix
b netip.Prefix
equal bool
}{
"both nil": {
"both_not_valid": {
equal: true,
},
"first nil": {
b: &net.IPNet{},
"first_not_valid": {
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
},
"second nil": {
a: &net.IPNet{},
"second_not_valid": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
},
"both not nil": {
a: &net.IPNet{},
b: &net.IPNet{},
"both_equal": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
equal: true,
},
"both equal": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
b: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
equal: true,
"both_not_equal_by_IP": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 24),
},
"both not equal by IP": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
b: &net.IPNet{
IP: net.IPv4(2, 2, 2, 2),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
"both_not_equal_by_bits": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
},
"both not equal by mask": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
b: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
"both_not_equal_by_IP_and_bits": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
},
}
@@ -418,7 +438,7 @@ func Test_ipNetsAreEqual(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
equal := ipNetsAreEqual(testCase.a, testCase.b)
equal := ipPrefixesAreEqual(testCase.a, testCase.b)
assert.Equal(t, testCase.equal, equal)
})

View File

@@ -1,10 +1,8 @@
package routing
import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
@@ -16,14 +14,14 @@ var (
)
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return ip, fmt.Errorf("listing routes: %w", err)
}
defaultLinkIndex := -1
for _, route := range routes {
if route.Dst == nil {
if !route.Dst.IsValid() {
defaultLinkIndex = route.LinkIndex
break
}
@@ -34,17 +32,17 @@ func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
for _, route := range routes {
if route.LinkIndex == defaultLinkIndex &&
route.Dst != nil &&
!IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) &&
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
return netIPToNetipAddress(route.Dst.IP), nil
route.Dst.IsValid() &&
!IPIsPrivate(route.Dst.Addr()) &&
route.Dst.IsSingleIP() {
return route.Dst.Addr(), nil
}
}
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
}
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return ip, fmt.Errorf("listing routes: %w", err)
}
@@ -53,11 +51,11 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
if err != nil {
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
}
interfaceName := link.Attrs().Name
interfaceName := link.Name
if interfaceName == vpnIntf &&
route.Dst != nil &&
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
return netIPToNetipAddress(route.Gw), nil
route.Dst.IsValid() &&
route.Dst.Addr().IsUnspecified() {
return route.Gw, nil
}
}
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))