Maint: internal/routing uses internal/netlink

This commit is contained in:
Quentin McGaw (desktop)
2021-08-23 20:56:10 +00:00
parent 7907146aaf
commit ee82a85543
10 changed files with 95 additions and 33 deletions

View File

@@ -25,7 +25,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, iface strin
" dev " + iface +
" table " + strconv.Itoa(table))
link, err := netlink.LinkByName(iface)
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("%w: interface %s: %s", ErrLinkByName, iface, err)
}
@@ -35,7 +35,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, iface strin
LinkIndex: link.Attrs().Index,
Table: table,
}
if err := netlink.RouteReplace(&route); err != nil {
if err := r.netLinker.RouteReplace(&route); err != nil {
return fmt.Errorf("%w: for subnet %s at interface %s: %s",
ErrRouteReplace, destinationStr, iface, err)
}
@@ -50,7 +50,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface st
" dev " + iface +
" table " + strconv.Itoa(table))
link, err := netlink.LinkByName(iface)
link, err := r.netLinker.LinkByName(iface)
if err != nil {
return fmt.Errorf("%w: for interface %s: %s", ErrLinkByName, iface, err)
}
@@ -60,7 +60,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface st
LinkIndex: link.Attrs().Index,
Table: table,
}
if err := netlink.RouteDel(&route); err != nil {
if err := r.netLinker.RouteDel(&route); err != nil {
return fmt.Errorf("%w: for subnet %s at interface %s: %s",
ErrRouteDelete, destinationStr, iface, err)
}
@@ -77,7 +77,7 @@ func (r *Routing) addIPRule(src net.IP, table, priority int) error {
rule.Priority = priority
rule.Table = table
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("%w: %s", ErrRulesList, err)
}
@@ -91,7 +91,7 @@ func (r *Routing) addIPRule(src net.IP, table, priority int) error {
}
}
if err := netlink.RuleAdd(rule); err != nil {
if err := r.netLinker.RuleAdd(rule); err != nil {
return fmt.Errorf("%w: for rule %q: %s", ErrRuleAdd, rule, err)
}
return nil
@@ -107,7 +107,7 @@ func (r *Routing) deleteIPRule(src net.IP, table, priority int) error {
rule.Priority = priority
rule.Table = table
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("%w: %s", ErrRulesList, err)
}
@@ -117,7 +117,7 @@ func (r *Routing) deleteIPRule(src net.IP, table, priority int) error {
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
existingRule.Priority == rule.Priority &&
existingRule.Table == rule.Table {
if err := netlink.RuleDel(rule); err != nil {
if err := r.netLinker.RuleDel(rule); err != nil {
return fmt.Errorf("%w: for rule %q: %s", ErrRuleDel, rule, err)
}
}

View File

@@ -38,7 +38,7 @@ type DefaultRouteGetter interface {
}
func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return "", nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
}
@@ -46,7 +46,7 @@ func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
if route.Dst == nil {
defaultGateway = route.Gw
linkIndex := route.LinkIndex
link, err := netlink.LinkByIndex(linkIndex)
link, err := r.netLinker.LinkByIndex(linkIndex)
if err != nil {
return "", nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err)
}
@@ -65,7 +65,7 @@ type DefaultIPGetter interface {
}
func (r *Routing) DefaultIP() (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
}
@@ -74,7 +74,7 @@ func (r *Routing) DefaultIP() (ip net.IP, err error) {
for _, route := range routes {
if route.Dst == nil {
linkIndex := route.LinkIndex
link, err := netlink.LinkByIndex(linkIndex)
link, err := r.netLinker.LinkByIndex(linkIndex)
if err != nil {
return nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err)
}
@@ -93,7 +93,7 @@ type LocalSubnetGetter interface {
}
func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return defaultSubnet, fmt.Errorf("%w: %s", ErrRoutesList, err)
}
@@ -126,7 +126,7 @@ type LocalNetworksGetter interface {
}
func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
links, err := netlink.LinkList()
links, err := r.netLinker.LinkList()
if err != nil {
return localNetworks, fmt.Errorf("%w: %s", ErrLinkList, err)
}
@@ -146,7 +146,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links))
}
routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4)
if err != nil {
return localNetworks, fmt.Errorf("%w: %s", ErrRoutesList, err)
}
@@ -163,7 +163,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
localNet.IPNet = route.Dst
r.logger.Info("local ipnet found: " + localNet.IPNet.String())
link, err := netlink.LinkByIndex(route.LinkIndex)
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
if err != nil {
return localNetworks, fmt.Errorf("%w: at index %d: %s", ErrLinkByIndex, route.LinkIndex, err)
}
@@ -213,7 +213,7 @@ type VPNDestinationIPGetter interface {
}
func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
}
@@ -245,12 +245,12 @@ type VPNLocalGatewayIPGetter interface {
}
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
}
for _, route := range routes {
link, err := netlink.LinkByIndex(route.LinkIndex)
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err)
}

View File

@@ -5,6 +5,7 @@ import (
"net"
"sync"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/golibs/logging"
)
@@ -33,14 +34,17 @@ type Writer interface {
}
type Routing struct {
netLinker netlink.NetLinker
logger logging.Logger
outboundSubnets []net.IPNet
stateMutex sync.RWMutex
}
// New creates a new routing instance.
func New(logger logging.Logger) *Routing {
func New(netLinker netlink.NetLinker,
logger logging.Logger) *Routing {
return &Routing{
logger: logger,
netLinker: netLinker,
logger: logger,
}
}