Netlink Go library to interact with IP routes (#267)
This commit is contained in:
@@ -81,7 +81,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
alpineConf := alpine.NewConfigurator(fileManager)
|
||||
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||
routingConf := routing.NewRouting(logger, fileManager)
|
||||
routingConf := routing.NewRouting(logger)
|
||||
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
||||
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
||||
streamMerger := command.NewStreamMerger()
|
||||
@@ -364,7 +364,6 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
|
||||
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
||||
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
|
||||
@@ -388,17 +387,12 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
|
||||
tickerWg.Add(2) //nolint:gomnd
|
||||
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||
defaultInterface, _, err := routing.DefaultRoute()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
vpnDestination, err := routing.VPNDestinationIP(defaultInterface)
|
||||
vpnDestination, err := routing.VPNDestinationIP()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
logger.Info("VPN routing IP address: %s", vpnDestination)
|
||||
}
|
||||
}
|
||||
if portForwardingEnabled {
|
||||
// TODO make instantaneous once v3 go out of service
|
||||
const waitDuration = 5 * time.Second
|
||||
|
||||
1
go.mod
1
go.mod
@@ -9,6 +9,7 @@ require (
|
||||
github.com/qdm12/golibs v0.0.0-20201018204514-1d5986880422
|
||||
github.com/qdm12/ss-server v0.0.0-20200819124651-6428e626ee83
|
||||
github.com/stretchr/testify v1.6.1
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0
|
||||
golang.org/x/sys v0.0.0-20201018121011-98379d014ca7
|
||||
)
|
||||
|
||||
5
go.sum
5
go.sum
@@ -87,6 +87,10 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
|
||||
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
||||
go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY=
|
||||
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||
go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc=
|
||||
@@ -116,6 +120,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201018121011-98379d014ca7 h1:CNOpL+H7PSxBI7dF/EIUsfOguRSzWp6CQ91yxZE6PG4=
|
||||
|
||||
@@ -109,7 +109,7 @@ func (c *configurator) enable(ctx context.Context) (err error) {
|
||||
}
|
||||
// Re-ensure all routes exist
|
||||
for _, subnet := range c.allowedSubnets {
|
||||
if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||
if err := c.routing.AddRouteVia(subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes")
|
||||
c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets)
|
||||
c.updateSubnetRoutes(c.allowedSubnets, subnets)
|
||||
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
||||
copy(c.allowedSubnets, subnets)
|
||||
return nil
|
||||
@@ -95,7 +95,7 @@ func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, d
|
||||
failed = true
|
||||
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||
}
|
||||
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
||||
if err := c.routing.DeleteRouteVia(subnet); err != nil {
|
||||
failed = true
|
||||
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa
|
||||
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil {
|
||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||
}
|
||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
||||
if err := c.routing.AddRouteVia(subnet, defaultGateway, defaultInterface); err != nil {
|
||||
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
|
||||
}
|
||||
c.allowedSubnets = append(c.allowedSubnets, subnet)
|
||||
@@ -125,19 +125,19 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa
|
||||
}
|
||||
|
||||
// updateSubnetRoutes does not return an error in order to try to run as many route commands as possible.
|
||||
func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) {
|
||||
func (c *configurator) updateSubnetRoutes(oldSubnets, newSubnets []net.IPNet) {
|
||||
subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets)
|
||||
subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets)
|
||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||
return
|
||||
}
|
||||
for _, subnet := range subnetsToRemove {
|
||||
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
||||
if err := c.routing.DeleteRouteVia(subnet); err != nil {
|
||||
c.logger.Error("cannot remove outdated route for subnet: %s", err)
|
||||
}
|
||||
}
|
||||
for _, subnet := range subnetsToAdd {
|
||||
if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||
if err := c.routing.AddRouteVia(subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||
c.logger.Error("cannot add route for subnet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type routingEntry struct {
|
||||
iface string
|
||||
destination net.IP
|
||||
gateway net.IP
|
||||
flags string
|
||||
refCount int
|
||||
use int
|
||||
metric int
|
||||
mask net.IPMask
|
||||
mtu int
|
||||
window int
|
||||
irtt int
|
||||
}
|
||||
|
||||
func parseRoutingEntry(s string) (r routingEntry, err error) {
|
||||
wrapError := func(err error) error {
|
||||
return fmt.Errorf("line %q: %w", s, err)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
const minFields = 11
|
||||
if len(fields) < minFields {
|
||||
return r, wrapError(fmt.Errorf("not enough fields"))
|
||||
}
|
||||
r.iface = fields[0]
|
||||
r.destination, err = reversedHexToIPv4(fields[1])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.gateway, err = reversedHexToIPv4(fields[2])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.flags = fields[3]
|
||||
r.refCount, err = strconv.Atoi(fields[4])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.use, err = strconv.Atoi(fields[5])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.metric, err = strconv.Atoi(fields[6])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.mask, err = hexToIPv4Mask(fields[7])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.mtu, err = strconv.Atoi(fields[8])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.window, err = strconv.Atoi(fields[9])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
r.irtt, err = strconv.Atoi(fields[10])
|
||||
if err != nil {
|
||||
return r, wrapError(err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func reversedHexToIPv4(reversedHex string) (ip net.IP, err error) {
|
||||
bytes, err := hex.DecodeString(reversedHex)
|
||||
const nBytesRequired = 4
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
|
||||
} else if L := len(bytes); L != nBytesRequired {
|
||||
return nil, fmt.Errorf("hex string contains %d bytes instead of %d", L, nBytesRequired)
|
||||
}
|
||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
||||
}
|
||||
|
||||
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
|
||||
bytes, err := hex.DecodeString(hexString)
|
||||
const nBytesRequired = 4
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
|
||||
} else if L := len(bytes); L != nBytesRequired {
|
||||
return nil, fmt.Errorf("hex string contains %d bytes instead of %d", L, nBytesRequired)
|
||||
}
|
||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//nolint:lll
|
||||
func Test_parseRoutingEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
s string
|
||||
r routingEntry
|
||||
err error
|
||||
}{
|
||||
"empty string": {
|
||||
err: fmt.Errorf("line \"\": not enough fields"),
|
||||
},
|
||||
"not enough fields": {
|
||||
s: "a b c d e",
|
||||
err: fmt.Errorf("line \"a b c d e\": not enough fields"),
|
||||
},
|
||||
"bad destination": {
|
||||
s: "eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0",
|
||||
err: fmt.Errorf("line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
||||
},
|
||||
"bad gateway": {
|
||||
s: "eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
||||
},
|
||||
"bad ref count": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
||||
},
|
||||
"bad use": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
||||
},
|
||||
"bad metric": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
||||
},
|
||||
"bad mask": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0\": cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
||||
},
|
||||
"bad mtu": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
||||
},
|
||||
"bad window": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
||||
},
|
||||
"bad irtt": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x",
|
||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
||||
},
|
||||
"success": {
|
||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0",
|
||||
r: routingEntry{
|
||||
iface: "eth0",
|
||||
destination: net.IP{192, 168, 2, 0},
|
||||
gateway: net.IP{10, 0, 0, 1},
|
||||
flags: "0003",
|
||||
mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r, err := parseRoutingEntry(tc.s)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.r, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reversedHexToIPv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
reversedHex string
|
||||
IP net.IP
|
||||
err error
|
||||
}{
|
||||
"empty hex": {
|
||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
||||
"bad hex": {
|
||||
reversedHex: "x",
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"3 bytes hex": {
|
||||
reversedHex: "9abcde",
|
||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
||||
"correct hex": {
|
||||
reversedHex: "010011AC",
|
||||
IP: []byte{0xac, 0x11, 0x0, 0x1},
|
||||
err: nil},
|
||||
"correct hex 2": {
|
||||
reversedHex: "000011AC",
|
||||
IP: []byte{0xac, 0x11, 0x0, 0x0},
|
||||
err: nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
IP, err := reversedHexToIPv4(tc.reversedHex)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.IP, IP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_hexMaskToDecMask(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
hexString string
|
||||
mask net.IPMask
|
||||
err error
|
||||
}{
|
||||
"empty hex": {
|
||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
||||
"bad hex": {
|
||||
hexString: "x",
|
||||
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"3 bytes hex": {
|
||||
hexString: "9abcde",
|
||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
||||
"16": {
|
||||
hexString: "0000FFFF",
|
||||
mask: []byte{0xff, 0xff, 0x0, 0x0},
|
||||
err: nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mask, err := hexToIPv4Mask(tc.hexString)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.mask, mask)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,48 +1,45 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (r *routing) AddRouteVia(ctx context.Context,
|
||||
subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
subnetStr := subnet.String()
|
||||
r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface)
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists {
|
||||
return nil
|
||||
}
|
||||
func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("adding route for %s", destinationStr)
|
||||
if r.debug {
|
||||
fmt.Printf("ip route add %s via %s dev %s\n", subnetStr, defaultGateway, defaultInterface)
|
||||
fmt.Printf("ip route add %s via %s dev %s\n", destinationStr, gateway, iface)
|
||||
}
|
||||
output, err := r.commander.Run(ctx,
|
||||
"ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
|
||||
link, err := netlink.LinkByName(iface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w",
|
||||
subnetStr, defaultGateway, "dev", defaultInterface, output, err)
|
||||
return fmt.Errorf("cannot add route for %s: %w", destinationStr, err)
|
||||
}
|
||||
route := netlink.Route{
|
||||
Dst: &destination,
|
||||
Gw: gateway,
|
||||
LinkIndex: link.Attrs().Index,
|
||||
}
|
||||
if err := netlink.RouteReplace(&route); err != nil {
|
||||
return fmt.Errorf("cannot add route for %s: %w", destinationStr, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) {
|
||||
subnetStr := subnet.String()
|
||||
r.logger.Info("deleting route for %s", subnetStr)
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !exists { // thanks to @npawelek https://github.com/npawelek
|
||||
return nil
|
||||
}
|
||||
func (r *routing) DeleteRouteVia(destination net.IPNet) (err error) {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("deleting route for %s", destinationStr)
|
||||
if r.debug {
|
||||
fmt.Printf("ip route del %s\n", subnetStr)
|
||||
fmt.Printf("ip route del %s\n", destinationStr)
|
||||
}
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err)
|
||||
route := netlink.Route{
|
||||
Dst: &destination,
|
||||
}
|
||||
if err := netlink.RouteDel(&route); err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/command/mock_command"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_DeleteRouteVia(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
tests := map[string]struct {
|
||||
subnet net.IPNet
|
||||
runOutput string
|
||||
runErr error
|
||||
err error
|
||||
}{
|
||||
"no output no error": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
},
|
||||
"error only": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
runErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"),
|
||||
},
|
||||
"error and output": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
runErr: fmt.Errorf("error"),
|
||||
runOutput: "output",
|
||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
subnetStr := tc.subnet.String()
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("deleting route for %s")
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr).
|
||||
Return(tc.runOutput, tc.runErr).Times(1)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
//nolint:lll
|
||||
routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`)
|
||||
fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil)
|
||||
r := &routing{
|
||||
logger: logger,
|
||||
commander: commander,
|
||||
fileManager: fileManager,
|
||||
}
|
||||
|
||||
err := r.DeleteRouteVia(ctx, tc.subnet)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,120 +4,105 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
|
||||
lines := strings.Split(strings.TrimSuffix(string(data), "\n"), "\n")
|
||||
lines = lines[1:]
|
||||
entries = make([]routingEntry, len(lines))
|
||||
for i := range lines {
|
||||
entries[i], err = parseRoutingEntry(lines[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("line %d in %s: %w", i+1, constants.NetRoute, err)
|
||||
}
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) {
|
||||
data, err := fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseRoutingTable(data)
|
||||
}
|
||||
|
||||
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, fmt.Errorf("cannot list routes: %w", err)
|
||||
}
|
||||
const minEntries = 2
|
||||
if len(entries) < minEntries {
|
||||
return "", nil, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil {
|
||||
defaultGateway = route.Gw
|
||||
linkIndex := route.LinkIndex
|
||||
link, err := netlink.LinkByIndex(linkIndex)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("cannot obtain link with index %d for default route: %w", linkIndex, err)
|
||||
}
|
||||
var defaultRouteEntry routingEntry
|
||||
for _, entry := range entries {
|
||||
if entry.mask.String() == "00000000" {
|
||||
defaultRouteEntry = entry
|
||||
break
|
||||
}
|
||||
}
|
||||
if defaultRouteEntry.iface == "" {
|
||||
return "", nil, fmt.Errorf("cannot find default route")
|
||||
}
|
||||
defaultInterface = defaultRouteEntry.iface
|
||||
defaultGateway = defaultRouteEntry.gateway
|
||||
attributes := link.Attrs()
|
||||
defaultInterface = attributes.Name
|
||||
r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String())
|
||||
return defaultInterface, defaultGateway, nil
|
||||
}
|
||||
}
|
||||
return "", nil, fmt.Errorf("cannot find default route in %d routes", len(routes))
|
||||
}
|
||||
|
||||
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return defaultSubnet, err
|
||||
return defaultSubnet, fmt.Errorf("cannot find local subnet: %w", err)
|
||||
}
|
||||
const minEntries = 2
|
||||
if len(entries) < minEntries {
|
||||
return defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
||||
}
|
||||
var localSubnetEntry routingEntry
|
||||
for _, entry := range entries {
|
||||
if entry.gateway.Equal(net.IP{0, 0, 0, 0}) && !strings.HasPrefix(entry.iface, "tun") {
|
||||
localSubnetEntry = entry
|
||||
|
||||
defaultLinkIndex := -1
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil {
|
||||
defaultLinkIndex = route.LinkIndex
|
||||
break
|
||||
}
|
||||
}
|
||||
if localSubnetEntry.iface == "" {
|
||||
return defaultSubnet, fmt.Errorf("cannot find local subnet route")
|
||||
if defaultLinkIndex == -1 {
|
||||
return defaultSubnet, fmt.Errorf("cannot find local subnet: cannot find default link")
|
||||
}
|
||||
defaultSubnet = net.IPNet{IP: localSubnetEntry.destination, Mask: localSubnetEntry.mask}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Gw != nil || route.LinkIndex != defaultLinkIndex {
|
||||
continue
|
||||
}
|
||||
defaultSubnet = *route.Dst
|
||||
r.logger.Info("local subnet found: %s", defaultSubnet.String())
|
||||
return defaultSubnet, nil
|
||||
}
|
||||
|
||||
return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes))
|
||||
}
|
||||
|
||||
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("cannot check route existence: %w", err)
|
||||
return nil, fmt.Errorf("cannot find VPN destination IP: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entrySubnet := net.IPNet{IP: entry.destination, Mask: entry.mask}
|
||||
if entrySubnet.String() == subnet.String() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.iface == defaultInterface &&
|
||||
!ipIsPrivate(entry.destination) &&
|
||||
bytes.Equal(entry.mask, net.IPMask{255, 255, 255, 255}) {
|
||||
return entry.destination, nil
|
||||
defaultLinkIndex := -1
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil {
|
||||
defaultLinkIndex = route.LinkIndex
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes")
|
||||
if defaultLinkIndex == -1 {
|
||||
return nil, fmt.Errorf("cannot find VPN destination IP: cannot find default link")
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.LinkIndex == defaultLinkIndex &&
|
||||
route.Dst != nil &&
|
||||
!ipIsPrivate(route.Dst.IP) &&
|
||||
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
|
||||
return route.Dst.IP, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("cannot find VPN destination IP address from ip routes")
|
||||
}
|
||||
|
||||
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err)
|
||||
return nil, fmt.Errorf("cannot find VPN local gateway IP: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.iface == string(constants.TUN) &&
|
||||
entry.destination.Equal(net.IP{0, 0, 0, 0}) {
|
||||
return entry.gateway, nil
|
||||
for _, route := range routes {
|
||||
link, err := netlink.LinkByIndex(route.LinkIndex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find VPN local gateway IP: %w", err)
|
||||
}
|
||||
interfaceName := link.Attrs().Name
|
||||
if interfaceName == string(constants.TUN) &&
|
||||
route.Dst != nil &&
|
||||
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
|
||||
return route.Gw, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes")
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//nolint:lll
|
||||
const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
tun0 00000000 050A030A 0003 0 0 0 00000080 0 0 0
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
tun0 010A030A 050A030A 0007 0 0 0 FFFFFFFF 0 0 0
|
||||
tun0 050A030A 00000000 0005 0 0 0 FFFFFFFF 0 0 0
|
||||
eth0 42196956 010011AC 0007 0 0 0 FFFFFFFF 0 0 0
|
||||
tun0 00000080 050A030A 0003 0 0 0 00000080 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
||||
`
|
||||
|
||||
//nolint:lll
|
||||
func Test_parseRoutingTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
entries []routingEntry
|
||||
err error
|
||||
}{
|
||||
"nil data": {
|
||||
entries: []routingEntry{},
|
||||
},
|
||||
"legend only": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
`),
|
||||
entries: []routingEntry{},
|
||||
},
|
||||
"legend and single line": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`),
|
||||
entries: []routingEntry{{
|
||||
iface: "eth0",
|
||||
destination: net.IP{192, 168, 2, 0},
|
||||
gateway: net.IP{10, 0, 0, 1},
|
||||
flags: "0003",
|
||||
mask: net.IPMask{255, 255, 255, 0},
|
||||
}},
|
||||
},
|
||||
"legend and two lines": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
eth0 0002A8C0 0100000A 0002 0 0 0 00FFFFFF 0 0 0
|
||||
`),
|
||||
entries: []routingEntry{
|
||||
{
|
||||
iface: "eth0",
|
||||
destination: net.IP{192, 168, 2, 0},
|
||||
gateway: net.IP{10, 0, 0, 1},
|
||||
flags: "0003",
|
||||
mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
{
|
||||
iface: "eth0",
|
||||
destination: net.IP{192, 168, 2, 0},
|
||||
gateway: net.IP{10, 0, 0, 1},
|
||||
flags: "0002",
|
||||
mask: net.IPMask{255, 255, 255, 0},
|
||||
}},
|
||||
},
|
||||
"error": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`),
|
||||
entries: nil,
|
||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
entries, err := parseRoutingTable(tc.data)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.entries, entries)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
func Test_DefaultRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
readErr error
|
||||
defaultInterface string
|
||||
defaultGateway net.IP
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error")},
|
||||
"parse error": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 x
|
||||
`),
|
||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
||||
"single entry": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
|
||||
`),
|
||||
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
||||
"success": {
|
||||
data: []byte(exampleRouteData),
|
||||
defaultInterface: "eth0",
|
||||
defaultGateway: net.IP{172, 17, 0, 1},
|
||||
},
|
||||
"not found": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 10000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
||||
`),
|
||||
err: fmt.Errorf("cannot find default route"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
||||
|
||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Times(1)
|
||||
if tc.err == nil {
|
||||
logger.EXPECT().Info(
|
||||
"default route found: interface %s, gateway %s",
|
||||
tc.defaultInterface, tc.defaultGateway.String(),
|
||||
).Times(1)
|
||||
}
|
||||
r := &routing{logger: logger, fileManager: filemanager}
|
||||
defaultInterface, defaultGateway, err := r.DefaultRoute()
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.defaultInterface, defaultInterface)
|
||||
assert.Equal(t, tc.defaultGateway, defaultGateway)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
func Test_LocalSubnet(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
readErr error
|
||||
localSubnet net.IPNet
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error")},
|
||||
"parse error": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 x
|
||||
`),
|
||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
||||
"single entry": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
|
||||
`),
|
||||
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
||||
"success": {
|
||||
data: []byte(exampleRouteData),
|
||||
localSubnet: net.IPNet{
|
||||
IP: net.IP{172, 17, 0, 0},
|
||||
Mask: net.IPMask{255, 255, 0, 0},
|
||||
},
|
||||
},
|
||||
"not found": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 10000000 0001 0 0 0 0000FFFF 0 0 0
|
||||
`),
|
||||
err: fmt.Errorf("cannot find local subnet route"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
||||
|
||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Times(1)
|
||||
if tc.err == nil {
|
||||
logger.EXPECT().Info("local subnet found: %s", tc.localSubnet.String()).Times(1)
|
||||
}
|
||||
r := &routing{logger: logger, fileManager: filemanager}
|
||||
localSubnet, err := r.LocalSubnet()
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.localSubnet, localSubnet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
func Test_routeExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
subnet net.IPNet
|
||||
data []byte
|
||||
readErr error
|
||||
exists bool
|
||||
err error
|
||||
}{
|
||||
"no data": {},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("cannot check route existence: error"),
|
||||
},
|
||||
"parse error": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 x
|
||||
`),
|
||||
err: fmt.Errorf("cannot check route existence: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"),
|
||||
},
|
||||
"not existing": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 128},
|
||||
},
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`),
|
||||
},
|
||||
"existing": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`),
|
||||
exists: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Times(1)
|
||||
r := &routing{fileManager: filemanager}
|
||||
exists, err := r.routeExists(tc.subnet)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.exists, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
func Test_VPNDestinationIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
defaultInterface string
|
||||
data []byte
|
||||
readErr error
|
||||
ip net.IP
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf("cannot find VPN gateway IP address from ip routes"),
|
||||
},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("cannot find VPN gateway IP address: error"),
|
||||
},
|
||||
"parse error": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 x
|
||||
`),
|
||||
err: fmt.Errorf("cannot find VPN gateway IP address: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"),
|
||||
},
|
||||
"found eth0": {
|
||||
defaultInterface: "eth0",
|
||||
data: []byte(exampleRouteData),
|
||||
ip: net.IP{86, 105, 25, 66},
|
||||
},
|
||||
"not found tun0": {
|
||||
defaultInterface: "tun0",
|
||||
data: []byte(exampleRouteData),
|
||||
err: fmt.Errorf("cannot find VPN gateway IP address from ip routes"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Times(1)
|
||||
r := &routing{fileManager: filemanager}
|
||||
ip, err := r.VPNDestinationIP(tc.defaultInterface)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.ip, ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,37 +1,30 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type Routing interface {
|
||||
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
||||
AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error
|
||||
DeleteRouteVia(destination net.IPNet) (err error)
|
||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
||||
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
||||
VPNDestinationIP(defaultInterface string) (ip net.IP, err error)
|
||||
VPNDestinationIP() (ip net.IP, err error)
|
||||
VPNLocalGatewayIP() (ip net.IP, err error)
|
||||
SetDebug()
|
||||
}
|
||||
|
||||
type routing struct {
|
||||
commander command.Commander
|
||||
logger logging.Logger
|
||||
fileManager files.FileManager
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance.
|
||||
func NewRouting(logger logging.Logger, fileManager files.FileManager) Routing {
|
||||
func NewRouting(logger logging.Logger) Routing {
|
||||
return &routing{
|
||||
commander: command.NewCommander(),
|
||||
logger: logger.WithPrefix("routing: "),
|
||||
fileManager: fileManager,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user