Routing improvements (#268)

- Fixes #82 
- Remove `EXTRA_SUBNETS`
- Remove no longer needed iptables rules
- Reduce routing interface arity
- Routing setup is done in main.go instead of in the firewall
- Routing setup gets reverted at shutdown
This commit is contained in:
Quentin McGaw
2020-10-24 18:05:11 -04:00
committed by GitHub
parent 716eb14da1
commit ed4fcc17b3
15 changed files with 209 additions and 251 deletions

View File

@@ -0,0 +1,62 @@
package routing
import (
"fmt"
"net"
)
var (
ErrSetup = fmt.Errorf("cannot setup routing")
ErrTeardown = fmt.Errorf("cannot teardown routing")
)
const (
table = 200
priority = 100
)
func (r *routing) Setup() (err error) {
defaultIP, err := r.defaultIP()
if err != nil {
return fmt.Errorf("%s: %w", ErrSetup, err)
}
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
if err != nil {
return fmt.Errorf("%s: %w", ErrSetup, err)
}
defer func() {
if err == nil {
return
}
if err := r.TearDown(); err != nil {
r.logger.Error(err)
}
}()
if err := r.addIPRule(defaultIP, table, priority); err != nil {
return fmt.Errorf("%s: %w", ErrSetup, err)
}
if err := r.addRouteVia(net.IPNet{}, defaultGateway, defaultInterfaceName, table); err != nil {
return fmt.Errorf("%s: %w", ErrSetup, err)
}
return nil
}
func (r *routing) TearDown() error {
defaultIP, err := r.defaultIP()
if err != nil {
return fmt.Errorf("%s: %w", ErrTeardown, err)
}
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
if err != nil {
return fmt.Errorf("%s: %w", ErrTeardown, err)
}
if err := r.deleteRouteVia(net.IPNet{}, defaultGateway, defaultInterfaceName, table); err != nil {
return fmt.Errorf("%s: %w", ErrTeardown, err)
}
if err := r.deleteIPRule(defaultIP, table, priority); err != nil {
return fmt.Errorf("%s: %w", ErrTeardown, err)
}
return nil
}

View File

@@ -1,17 +1,18 @@
package routing
import (
"bytes"
"fmt"
"net"
"github.com/vishvananda/netlink"
)
func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error {
func (r *routing) addRouteVia(destination net.IPNet, gateway net.IP, iface string, table int) error {
destinationStr := destination.String()
r.logger.Info("adding route for %s", destinationStr)
if r.debug {
fmt.Printf("ip route add %s via %s dev %s\n", destinationStr, gateway, iface)
fmt.Printf("ip route replace %s via %s dev %s table %d\n", destinationStr, gateway, iface, table)
}
link, err := netlink.LinkByName(iface)
@@ -22,6 +23,7 @@ func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface strin
Dst: &destination,
Gw: gateway,
LinkIndex: link.Attrs().Index,
Table: table,
}
if err := netlink.RouteReplace(&route); err != nil {
return fmt.Errorf("cannot add route for %s: %w", destinationStr, err)
@@ -29,17 +31,80 @@ func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface strin
return nil
}
func (r *routing) DeleteRouteVia(destination net.IPNet) (err error) {
func (r *routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface string, table int) (err error) {
destinationStr := destination.String()
r.logger.Info("deleting route for %s", destinationStr)
if r.debug {
fmt.Printf("ip route del %s\n", destinationStr)
fmt.Printf("ip route delete %s via %s dev %s table %d\n", destinationStr, gateway, iface, table)
}
link, err := netlink.LinkByName(iface)
if err != nil {
return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err)
}
route := netlink.Route{
Dst: &destination,
Dst: &destination,
Gw: gateway,
LinkIndex: link.Attrs().Index,
Table: table,
}
if err := netlink.RouteDel(&route); err != nil {
return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err)
}
return nil
}
func (r *routing) addIPRule(src net.IP, table, priority int) error {
if r.debug {
fmt.Printf("ip rule add from %s lookup %d pref %d\n",
src, table, priority)
}
rule := netlink.NewRule()
rule.Src = netlink.NewIPNet(src)
rule.Priority = priority
rule.Table = table
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("cannot add ip rule: %w", err)
}
for _, existingRule := range rules {
if existingRule.Src != nil &&
existingRule.Src.IP.Equal(rule.Src.IP) &&
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
existingRule.Priority == rule.Priority &&
existingRule.Table == rule.Table {
return nil // already exists
}
}
return netlink.RuleAdd(rule)
}
func (r *routing) deleteIPRule(src net.IP, table, priority int) error {
if r.debug {
fmt.Printf("ip rule del from %s lookup %d pref %d\n",
src, table, priority)
}
rule := netlink.NewRule()
rule.Src = netlink.NewIPNet(src)
rule.Priority = priority
rule.Table = table
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("cannot add ip rule: %w", err)
}
for _, existingRule := range rules {
if existingRule.Src != nil &&
existingRule.Src.IP.Equal(rule.Src.IP) &&
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
existingRule.Priority == rule.Priority &&
existingRule.Table == rule.Table {
return netlink.RuleDel(rule)
}
}
return nil
}

View File

@@ -31,6 +31,30 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
return "", nil, fmt.Errorf("cannot find default route in %d routes", len(routes))
}
func (r *routing) defaultIP() (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("cannot get default IP address: %w", err)
}
defaultLinkName := ""
for _, route := range routes {
if route.Dst == nil {
linkIndex := route.LinkIndex
link, err := netlink.LinkByIndex(linkIndex)
if err != nil {
return nil, fmt.Errorf("cannot get default IP address: %w", err)
}
defaultLinkName = link.Attrs().Name
}
}
if len(defaultLinkName) == 0 {
return nil, fmt.Errorf("cannot find default link name in %d routes", len(routes))
}
return r.assignedIP(defaultLinkName)
}
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
@@ -60,6 +84,26 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes))
}
func (r *routing) assignedIP(interfaceName string) (ip net.IP, err error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return nil, err
}
addresses, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, address := range addresses {
switch value := address.(type) {
case *net.IPAddr:
return value.IP, nil
case *net.IPNet:
return value.IP, nil
}
}
return nil, fmt.Errorf("IP address not found in addresses of interface %s", interfaceName)
}
func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {

View File

@@ -7,8 +7,8 @@ import (
)
type Routing interface {
AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error
DeleteRouteVia(destination net.IPNet) (err error)
Setup() (err error)
TearDown() error
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
LocalSubnet() (defaultSubnet net.IPNet, err error)
VPNDestinationIP() (ip net.IP, err error)