From ee82a85543d6f29b57eba64998bdb6e230d33e6e Mon Sep 17 00:00:00 2001 From: "Quentin McGaw (desktop)" Date: Mon, 23 Aug 2021 20:56:10 +0000 Subject: [PATCH] Maint: `internal/routing` uses `internal/netlink` --- cmd/gluetun/main.go | 2 +- internal/netlink/address.go | 6 ++++++ internal/netlink/interface.go | 12 ++++-------- internal/netlink/link.go | 23 +++++++++++++++++++++++ internal/netlink/route.go | 23 +++++++++++++++++++++++ internal/netlink/rule.go | 12 ++++++++++++ internal/routing/mutate.go | 16 ++++++++-------- internal/routing/reader.go | 22 +++++++++++----------- internal/routing/routing.go | 8 ++++++-- internal/wireguard/constructor.go | 4 +--- 10 files changed, 95 insertions(+), 33 deletions(-) create mode 100644 internal/netlink/link.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 6892267b..a10c70c1 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -234,7 +234,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, Prefix: "routing: ", Level: firewallLogLevel, }) - routingConf := routing.New(routingLogger) + routingConf := routing.New(netLinker, routingLogger) defaultInterface, defaultGateway, err := routingConf.DefaultRoute() if err != nil { diff --git a/internal/netlink/address.go b/internal/netlink/address.go index 8252e962..c972c27a 100644 --- a/internal/netlink/address.go +++ b/internal/netlink/address.go @@ -2,6 +2,12 @@ package netlink import "github.com/vishvananda/netlink" +var _ Addresser = (*NetLink)(nil) + +type Addresser interface { + AddrAdd(link netlink.Link, addr *netlink.Addr) error +} + func (n *NetLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error { return netlink.AddrAdd(link, addr) } diff --git a/internal/netlink/interface.go b/internal/netlink/interface.go index 0e091be2..c2239eb0 100644 --- a/internal/netlink/interface.go +++ b/internal/netlink/interface.go @@ -1,14 +1,10 @@ package netlink -import "github.com/vishvananda/netlink" - -//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . NetLinker - var _ NetLinker = (*NetLink)(nil) type NetLinker interface { - AddrAdd(link netlink.Link, addr *netlink.Addr) error - RouteAdd(route *netlink.Route) error - RuleAdd(rule *netlink.Rule) error - RuleDel(rule *netlink.Rule) error + Addresser + Linker + Router + Ruler } diff --git a/internal/netlink/link.go b/internal/netlink/link.go new file mode 100644 index 00000000..80223e75 --- /dev/null +++ b/internal/netlink/link.go @@ -0,0 +1,23 @@ +package netlink + +import "github.com/vishvananda/netlink" + +var _ Linker = (*NetLink)(nil) + +type Linker interface { + LinkList() (links []netlink.Link, err error) + LinkByName(name string) (link netlink.Link, err error) + LinkByIndex(index int) (link netlink.Link, err error) +} + +func (n *NetLink) LinkList() (links []netlink.Link, err error) { + return netlink.LinkList() +} + +func (n *NetLink) LinkByName(name string) (link netlink.Link, err error) { + return netlink.LinkByName(name) +} + +func (n *NetLink) LinkByIndex(index int) (link netlink.Link, err error) { + return netlink.LinkByIndex(index) +} diff --git a/internal/netlink/route.go b/internal/netlink/route.go index 13a3a4b1..1414c0e4 100644 --- a/internal/netlink/route.go +++ b/internal/netlink/route.go @@ -2,6 +2,29 @@ package netlink import "github.com/vishvananda/netlink" +var _ Router = (*NetLink)(nil) + +type Router interface { + RouteList(link netlink.Link, family int) ( + routes []netlink.Route, err error) + RouteAdd(route *netlink.Route) error + RouteDel(route *netlink.Route) error + RouteReplace(route *netlink.Route) error +} + +func (n *NetLink) RouteList(link netlink.Link, family int) ( + routes []netlink.Route, err error) { + return netlink.RouteList(link, family) +} + func (n *NetLink) RouteAdd(route *netlink.Route) error { return netlink.RouteAdd(route) } + +func (n *NetLink) RouteDel(route *netlink.Route) error { + return netlink.RouteDel(route) +} + +func (n *NetLink) RouteReplace(route *netlink.Route) error { + return netlink.RouteReplace(route) +} diff --git a/internal/netlink/rule.go b/internal/netlink/rule.go index 9952460c..f6087941 100644 --- a/internal/netlink/rule.go +++ b/internal/netlink/rule.go @@ -2,6 +2,18 @@ package netlink import "github.com/vishvananda/netlink" +var _ Ruler = (*NetLink)(nil) + +type Ruler interface { + RuleList(family int) (rules []netlink.Rule, err error) + RuleAdd(rule *netlink.Rule) error + RuleDel(rule *netlink.Rule) error +} + +func (n *NetLink) RuleList(family int) (rules []netlink.Rule, err error) { + return netlink.RuleList(family) +} + func (n *NetLink) RuleAdd(rule *netlink.Rule) error { return netlink.RuleAdd(rule) } diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go index d09568ed..39d87ccc 100644 --- a/internal/routing/mutate.go +++ b/internal/routing/mutate.go @@ -25,7 +25,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, iface strin " dev " + iface + " table " + strconv.Itoa(table)) - link, err := netlink.LinkByName(iface) + link, err := r.netLinker.LinkByName(iface) if err != nil { return fmt.Errorf("%w: interface %s: %s", ErrLinkByName, iface, err) } @@ -35,7 +35,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, iface strin LinkIndex: link.Attrs().Index, Table: table, } - if err := netlink.RouteReplace(&route); err != nil { + if err := r.netLinker.RouteReplace(&route); err != nil { return fmt.Errorf("%w: for subnet %s at interface %s: %s", ErrRouteReplace, destinationStr, iface, err) } @@ -50,7 +50,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface st " dev " + iface + " table " + strconv.Itoa(table)) - link, err := netlink.LinkByName(iface) + link, err := r.netLinker.LinkByName(iface) if err != nil { return fmt.Errorf("%w: for interface %s: %s", ErrLinkByName, iface, err) } @@ -60,7 +60,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface st LinkIndex: link.Attrs().Index, Table: table, } - if err := netlink.RouteDel(&route); err != nil { + if err := r.netLinker.RouteDel(&route); err != nil { return fmt.Errorf("%w: for subnet %s at interface %s: %s", ErrRouteDelete, destinationStr, iface, err) } @@ -77,7 +77,7 @@ func (r *Routing) addIPRule(src net.IP, table, priority int) error { rule.Priority = priority rule.Table = table - rules, err := netlink.RuleList(netlink.FAMILY_ALL) + rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) if err != nil { return fmt.Errorf("%w: %s", ErrRulesList, err) } @@ -91,7 +91,7 @@ func (r *Routing) addIPRule(src net.IP, table, priority int) error { } } - if err := netlink.RuleAdd(rule); err != nil { + if err := r.netLinker.RuleAdd(rule); err != nil { return fmt.Errorf("%w: for rule %q: %s", ErrRuleAdd, rule, err) } return nil @@ -107,7 +107,7 @@ func (r *Routing) deleteIPRule(src net.IP, table, priority int) error { rule.Priority = priority rule.Table = table - rules, err := netlink.RuleList(netlink.FAMILY_ALL) + rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) if err != nil { return fmt.Errorf("%w: %s", ErrRulesList, err) } @@ -117,7 +117,7 @@ func (r *Routing) deleteIPRule(src net.IP, table, priority int) error { bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && existingRule.Priority == rule.Priority && existingRule.Table == rule.Table { - if err := netlink.RuleDel(rule); err != nil { + if err := r.netLinker.RuleDel(rule); err != nil { return fmt.Errorf("%w: for rule %q: %s", ErrRuleDel, rule, err) } } diff --git a/internal/routing/reader.go b/internal/routing/reader.go index fc8ca3e0..db8007fe 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -38,7 +38,7 @@ type DefaultRouteGetter interface { } func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { - routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return "", nil, fmt.Errorf("%w: %s", ErrRoutesList, err) } @@ -46,7 +46,7 @@ func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP if route.Dst == nil { defaultGateway = route.Gw linkIndex := route.LinkIndex - link, err := netlink.LinkByIndex(linkIndex) + link, err := r.netLinker.LinkByIndex(linkIndex) if err != nil { return "", nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) } @@ -65,7 +65,7 @@ type DefaultIPGetter interface { } func (r *Routing) DefaultIP() (ip net.IP, err error) { - routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) } @@ -74,7 +74,7 @@ func (r *Routing) DefaultIP() (ip net.IP, err error) { for _, route := range routes { if route.Dst == nil { linkIndex := route.LinkIndex - link, err := netlink.LinkByIndex(linkIndex) + link, err := r.netLinker.LinkByIndex(linkIndex) if err != nil { return nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) } @@ -93,7 +93,7 @@ type LocalSubnetGetter interface { } func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { - routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return defaultSubnet, fmt.Errorf("%w: %s", ErrRoutesList, err) } @@ -126,7 +126,7 @@ type LocalNetworksGetter interface { } func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { - links, err := netlink.LinkList() + links, err := r.netLinker.LinkList() if err != nil { return localNetworks, fmt.Errorf("%w: %s", ErrLinkList, err) } @@ -146,7 +146,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links)) } - routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4) if err != nil { return localNetworks, fmt.Errorf("%w: %s", ErrRoutesList, err) } @@ -163,7 +163,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { localNet.IPNet = route.Dst r.logger.Info("local ipnet found: " + localNet.IPNet.String()) - link, err := netlink.LinkByIndex(route.LinkIndex) + link, err := r.netLinker.LinkByIndex(route.LinkIndex) if err != nil { return localNetworks, fmt.Errorf("%w: at index %d: %s", ErrLinkByIndex, route.LinkIndex, err) } @@ -213,7 +213,7 @@ type VPNDestinationIPGetter interface { } func (r *Routing) VPNDestinationIP() (ip net.IP, err error) { - routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) } @@ -245,12 +245,12 @@ type VPNLocalGatewayIPGetter interface { } func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) { - routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) } for _, route := range routes { - link, err := netlink.LinkByIndex(route.LinkIndex) + link, err := r.netLinker.LinkByIndex(route.LinkIndex) if err != nil { return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err) } diff --git a/internal/routing/routing.go b/internal/routing/routing.go index a246f49a..fe8dd147 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -5,6 +5,7 @@ import ( "net" "sync" + "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/golibs/logging" ) @@ -33,14 +34,17 @@ type Writer interface { } type Routing struct { + netLinker netlink.NetLinker logger logging.Logger outboundSubnets []net.IPNet stateMutex sync.RWMutex } // New creates a new routing instance. -func New(logger logging.Logger) *Routing { +func New(netLinker netlink.NetLinker, + logger logging.Logger) *Routing { return &Routing{ - logger: logger, + netLinker: netLinker, + logger: logger, } } diff --git a/internal/wireguard/constructor.go b/internal/wireguard/constructor.go index 70caa547..a8cbe954 100644 --- a/internal/wireguard/constructor.go +++ b/internal/wireguard/constructor.go @@ -1,7 +1,5 @@ package wireguard -import "github.com/qdm12/gluetun/internal/netlink" - var _ Wireguarder = (*Wireguard)(nil) type Wireguarder interface { @@ -12,7 +10,7 @@ type Wireguarder interface { type Wireguard struct { logger Logger settings Settings - netlink netlink.NetLinker + netlink NetLinker } func New(settings Settings, netlink NetLinker,