Routing improvements (#268)
- Fixes #82 - Remove `EXTRA_SUBNETS` - Remove no longer needed iptables rules - Reduce routing interface arity - Routing setup is done in main.go instead of in the firewall - Routing setup gets reverted at shutdown
This commit is contained in:
62
internal/routing/enable.go
Normal file
62
internal/routing/enable.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSetup = fmt.Errorf("cannot setup routing")
|
||||
ErrTeardown = fmt.Errorf("cannot teardown routing")
|
||||
)
|
||||
|
||||
const (
|
||||
table = 200
|
||||
priority = 100
|
||||
)
|
||||
|
||||
func (r *routing) Setup() (err error) {
|
||||
defaultIP, err := r.defaultIP()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if err := r.TearDown(); err != nil {
|
||||
r.logger.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := r.addIPRule(defaultIP, table, priority); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
if err := r.addRouteVia(net.IPNet{}, defaultGateway, defaultInterfaceName, table); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) TearDown() error {
|
||||
defaultIP, err := r.defaultIP()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrTeardown, err)
|
||||
}
|
||||
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrTeardown, err)
|
||||
}
|
||||
|
||||
if err := r.deleteRouteVia(net.IPNet{}, defaultGateway, defaultInterfaceName, table); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrTeardown, err)
|
||||
}
|
||||
if err := r.deleteIPRule(defaultIP, table, priority); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrTeardown, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,17 +1,18 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error {
|
||||
func (r *routing) addRouteVia(destination net.IPNet, gateway net.IP, iface string, table int) error {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("adding route for %s", destinationStr)
|
||||
if r.debug {
|
||||
fmt.Printf("ip route add %s via %s dev %s\n", destinationStr, gateway, iface)
|
||||
fmt.Printf("ip route replace %s via %s dev %s table %d\n", destinationStr, gateway, iface, table)
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(iface)
|
||||
@@ -22,6 +23,7 @@ func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface strin
|
||||
Dst: &destination,
|
||||
Gw: gateway,
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: table,
|
||||
}
|
||||
if err := netlink.RouteReplace(&route); err != nil {
|
||||
return fmt.Errorf("cannot add route for %s: %w", destinationStr, err)
|
||||
@@ -29,17 +31,80 @@ func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) DeleteRouteVia(destination net.IPNet) (err error) {
|
||||
func (r *routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface string, table int) (err error) {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("deleting route for %s", destinationStr)
|
||||
if r.debug {
|
||||
fmt.Printf("ip route del %s\n", destinationStr)
|
||||
fmt.Printf("ip route delete %s via %s dev %s table %d\n", destinationStr, gateway, iface, table)
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(iface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err)
|
||||
}
|
||||
route := netlink.Route{
|
||||
Dst: &destination,
|
||||
Dst: &destination,
|
||||
Gw: gateway,
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: table,
|
||||
}
|
||||
if err := netlink.RouteDel(&route); err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) addIPRule(src net.IP, table, priority int) error {
|
||||
if r.debug {
|
||||
fmt.Printf("ip rule add from %s lookup %d pref %d\n",
|
||||
src, table, priority)
|
||||
}
|
||||
|
||||
rule := netlink.NewRule()
|
||||
rule.Src = netlink.NewIPNet(src)
|
||||
rule.Priority = priority
|
||||
rule.Table = table
|
||||
|
||||
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add ip rule: %w", err)
|
||||
}
|
||||
for _, existingRule := range rules {
|
||||
if existingRule.Src != nil &&
|
||||
existingRule.Src.IP.Equal(rule.Src.IP) &&
|
||||
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
|
||||
existingRule.Priority == rule.Priority &&
|
||||
existingRule.Table == rule.Table {
|
||||
return nil // already exists
|
||||
}
|
||||
}
|
||||
|
||||
return netlink.RuleAdd(rule)
|
||||
}
|
||||
|
||||
func (r *routing) deleteIPRule(src net.IP, table, priority int) error {
|
||||
if r.debug {
|
||||
fmt.Printf("ip rule del from %s lookup %d pref %d\n",
|
||||
src, table, priority)
|
||||
}
|
||||
|
||||
rule := netlink.NewRule()
|
||||
rule.Src = netlink.NewIPNet(src)
|
||||
rule.Priority = priority
|
||||
rule.Table = table
|
||||
|
||||
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add ip rule: %w", err)
|
||||
}
|
||||
for _, existingRule := range rules {
|
||||
if existingRule.Src != nil &&
|
||||
existingRule.Src.IP.Equal(rule.Src.IP) &&
|
||||
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
|
||||
existingRule.Priority == rule.Priority &&
|
||||
existingRule.Table == rule.Table {
|
||||
return netlink.RuleDel(rule)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -31,6 +31,30 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
|
||||
return "", nil, fmt.Errorf("cannot find default route in %d routes", len(routes))
|
||||
}
|
||||
|
||||
func (r *routing) defaultIP() (ip net.IP, err error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot get default IP address: %w", err)
|
||||
}
|
||||
|
||||
defaultLinkName := ""
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil {
|
||||
linkIndex := route.LinkIndex
|
||||
link, err := netlink.LinkByIndex(linkIndex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot get default IP address: %w", err)
|
||||
}
|
||||
defaultLinkName = link.Attrs().Name
|
||||
}
|
||||
}
|
||||
if len(defaultLinkName) == 0 {
|
||||
return nil, fmt.Errorf("cannot find default link name in %d routes", len(routes))
|
||||
}
|
||||
|
||||
return r.assignedIP(defaultLinkName)
|
||||
}
|
||||
|
||||
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
@@ -60,6 +84,26 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||
return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes))
|
||||
}
|
||||
|
||||
func (r *routing) assignedIP(interfaceName string) (ip net.IP, err error) {
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addresses, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, address := range addresses {
|
||||
switch value := address.(type) {
|
||||
case *net.IPAddr:
|
||||
return value.IP, nil
|
||||
case *net.IPNet:
|
||||
return value.IP, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("IP address not found in addresses of interface %s", interfaceName)
|
||||
}
|
||||
|
||||
func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
)
|
||||
|
||||
type Routing interface {
|
||||
AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error
|
||||
DeleteRouteVia(destination net.IPNet) (err error)
|
||||
Setup() (err error)
|
||||
TearDown() error
|
||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
||||
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
||||
VPNDestinationIP() (ip net.IP, err error)
|
||||
|
||||
Reference in New Issue
Block a user