Netlink Go library to interact with IP routes (#267)
This commit is contained in:
@@ -81,7 +81,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
alpineConf := alpine.NewConfigurator(fileManager)
|
alpineConf := alpine.NewConfigurator(fileManager)
|
||||||
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
||||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||||
routingConf := routing.NewRouting(logger, fileManager)
|
routingConf := routing.NewRouting(logger)
|
||||||
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
||||||
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
||||||
streamMerger := command.NewStreamMerger()
|
streamMerger := command.NewStreamMerger()
|
||||||
@@ -364,7 +364,6 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit
|
|
||||||
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
|
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
|
||||||
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
||||||
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
|
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
|
||||||
@@ -388,17 +387,12 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
|
|||||||
tickerWg.Add(2) //nolint:gomnd
|
tickerWg.Add(2) //nolint:gomnd
|
||||||
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||||
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||||
defaultInterface, _, err := routing.DefaultRoute()
|
vpnDestination, err := routing.VPNDestinationIP()
|
||||||
if err != nil {
|
|
||||||
logger.Warn(err)
|
|
||||||
} else {
|
|
||||||
vpnDestination, err := routing.VPNDestinationIP(defaultInterface)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
} else {
|
} else {
|
||||||
logger.Info("VPN routing IP address: %s", vpnDestination)
|
logger.Info("VPN routing IP address: %s", vpnDestination)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if portForwardingEnabled {
|
if portForwardingEnabled {
|
||||||
// TODO make instantaneous once v3 go out of service
|
// TODO make instantaneous once v3 go out of service
|
||||||
const waitDuration = 5 * time.Second
|
const waitDuration = 5 * time.Second
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -9,6 +9,7 @@ require (
|
|||||||
github.com/qdm12/golibs v0.0.0-20201018204514-1d5986880422
|
github.com/qdm12/golibs v0.0.0-20201018204514-1d5986880422
|
||||||
github.com/qdm12/ss-server v0.0.0-20200819124651-6428e626ee83
|
github.com/qdm12/ss-server v0.0.0-20200819124651-6428e626ee83
|
||||||
github.com/stretchr/testify v1.6.1
|
github.com/stretchr/testify v1.6.1
|
||||||
|
github.com/vishvananda/netlink v1.1.0
|
||||||
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0
|
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0
|
||||||
golang.org/x/sys v0.0.0-20201018121011-98379d014ca7
|
golang.org/x/sys v0.0.0-20201018121011-98379d014ca7
|
||||||
)
|
)
|
||||||
|
|||||||
5
go.sum
5
go.sum
@@ -87,6 +87,10 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
|
|||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
|
||||||
|
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
||||||
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
|
||||||
|
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
||||||
go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY=
|
go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY=
|
||||||
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||||
go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc=
|
go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc=
|
||||||
@@ -116,6 +120,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
|||||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201018121011-98379d014ca7 h1:CNOpL+H7PSxBI7dF/EIUsfOguRSzWp6CQ91yxZE6PG4=
|
golang.org/x/sys v0.0.0-20201018121011-98379d014ca7 h1:CNOpL+H7PSxBI7dF/EIUsfOguRSzWp6CQ91yxZE6PG4=
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func (c *configurator) enable(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
// Re-ensure all routes exist
|
// Re-ensure all routes exist
|
||||||
for _, subnet := range c.allowedSubnets {
|
for _, subnet := range c.allowedSubnets {
|
||||||
if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
if err := c.routing.AddRouteVia(subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe
|
|||||||
|
|
||||||
if !c.enabled {
|
if !c.enabled {
|
||||||
c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes")
|
c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes")
|
||||||
c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets)
|
c.updateSubnetRoutes(c.allowedSubnets, subnets)
|
||||||
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
||||||
copy(c.allowedSubnets, subnets)
|
copy(c.allowedSubnets, subnets)
|
||||||
return nil
|
return nil
|
||||||
@@ -95,7 +95,7 @@ func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, d
|
|||||||
failed = true
|
failed = true
|
||||||
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||||
}
|
}
|
||||||
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
if err := c.routing.DeleteRouteVia(subnet); err != nil {
|
||||||
failed = true
|
failed = true
|
||||||
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
|
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
|
||||||
}
|
}
|
||||||
@@ -116,7 +116,7 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa
|
|||||||
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
if err := c.routing.AddRouteVia(subnet, defaultGateway, defaultInterface); err != nil {
|
||||||
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
|
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
|
||||||
}
|
}
|
||||||
c.allowedSubnets = append(c.allowedSubnets, subnet)
|
c.allowedSubnets = append(c.allowedSubnets, subnet)
|
||||||
@@ -125,19 +125,19 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateSubnetRoutes does not return an error in order to try to run as many route commands as possible.
|
// updateSubnetRoutes does not return an error in order to try to run as many route commands as possible.
|
||||||
func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) {
|
func (c *configurator) updateSubnetRoutes(oldSubnets, newSubnets []net.IPNet) {
|
||||||
subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets)
|
subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets)
|
||||||
subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets)
|
subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets)
|
||||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, subnet := range subnetsToRemove {
|
for _, subnet := range subnetsToRemove {
|
||||||
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
if err := c.routing.DeleteRouteVia(subnet); err != nil {
|
||||||
c.logger.Error("cannot remove outdated route for subnet: %s", err)
|
c.logger.Error("cannot remove outdated route for subnet: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, subnet := range subnetsToAdd {
|
for _, subnet := range subnetsToAdd {
|
||||||
if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
if err := c.routing.AddRouteVia(subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||||
c.logger.Error("cannot add route for subnet: %s", err)
|
c.logger.Error("cannot add route for subnet: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type routingEntry struct {
|
|
||||||
iface string
|
|
||||||
destination net.IP
|
|
||||||
gateway net.IP
|
|
||||||
flags string
|
|
||||||
refCount int
|
|
||||||
use int
|
|
||||||
metric int
|
|
||||||
mask net.IPMask
|
|
||||||
mtu int
|
|
||||||
window int
|
|
||||||
irtt int
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRoutingEntry(s string) (r routingEntry, err error) {
|
|
||||||
wrapError := func(err error) error {
|
|
||||||
return fmt.Errorf("line %q: %w", s, err)
|
|
||||||
}
|
|
||||||
fields := strings.Fields(s)
|
|
||||||
const minFields = 11
|
|
||||||
if len(fields) < minFields {
|
|
||||||
return r, wrapError(fmt.Errorf("not enough fields"))
|
|
||||||
}
|
|
||||||
r.iface = fields[0]
|
|
||||||
r.destination, err = reversedHexToIPv4(fields[1])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.gateway, err = reversedHexToIPv4(fields[2])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.flags = fields[3]
|
|
||||||
r.refCount, err = strconv.Atoi(fields[4])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.use, err = strconv.Atoi(fields[5])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.metric, err = strconv.Atoi(fields[6])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.mask, err = hexToIPv4Mask(fields[7])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.mtu, err = strconv.Atoi(fields[8])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.window, err = strconv.Atoi(fields[9])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
r.irtt, err = strconv.Atoi(fields[10])
|
|
||||||
if err != nil {
|
|
||||||
return r, wrapError(err)
|
|
||||||
}
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func reversedHexToIPv4(reversedHex string) (ip net.IP, err error) {
|
|
||||||
bytes, err := hex.DecodeString(reversedHex)
|
|
||||||
const nBytesRequired = 4
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
|
|
||||||
} else if L := len(bytes); L != nBytesRequired {
|
|
||||||
return nil, fmt.Errorf("hex string contains %d bytes instead of %d", L, nBytesRequired)
|
|
||||||
}
|
|
||||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
|
|
||||||
bytes, err := hex.DecodeString(hexString)
|
|
||||||
const nBytesRequired = 4
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
|
|
||||||
} else if L := len(bytes); L != nBytesRequired {
|
|
||||||
return nil, fmt.Errorf("hex string contains %d bytes instead of %d", L, nBytesRequired)
|
|
||||||
}
|
|
||||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
|
||||||
}
|
|
||||||
@@ -1,164 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
func Test_parseRoutingEntry(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
s string
|
|
||||||
r routingEntry
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"empty string": {
|
|
||||||
err: fmt.Errorf("line \"\": not enough fields"),
|
|
||||||
},
|
|
||||||
"not enough fields": {
|
|
||||||
s: "a b c d e",
|
|
||||||
err: fmt.Errorf("line \"a b c d e\": not enough fields"),
|
|
||||||
},
|
|
||||||
"bad destination": {
|
|
||||||
s: "eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
|
||||||
},
|
|
||||||
"bad gateway": {
|
|
||||||
s: "eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
|
||||||
},
|
|
||||||
"bad ref count": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
|
||||||
},
|
|
||||||
"bad use": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
|
||||||
},
|
|
||||||
"bad metric": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
|
||||||
},
|
|
||||||
"bad mask": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0\": cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
|
||||||
},
|
|
||||||
"bad mtu": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
|
||||||
},
|
|
||||||
"bad window": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
|
||||||
},
|
|
||||||
"bad irtt": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x",
|
|
||||||
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x\": strconv.Atoi: parsing \"x\": invalid syntax"),
|
|
||||||
},
|
|
||||||
"success": {
|
|
||||||
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0",
|
|
||||||
r: routingEntry{
|
|
||||||
iface: "eth0",
|
|
||||||
destination: net.IP{192, 168, 2, 0},
|
|
||||||
gateway: net.IP{10, 0, 0, 1},
|
|
||||||
flags: "0003",
|
|
||||||
mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
r, err := parseRoutingEntry(tc.s)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tc.r, r)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_reversedHexToIPv4(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
reversedHex string
|
|
||||||
IP net.IP
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"empty hex": {
|
|
||||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
|
||||||
"bad hex": {
|
|
||||||
reversedHex: "x",
|
|
||||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
|
||||||
"3 bytes hex": {
|
|
||||||
reversedHex: "9abcde",
|
|
||||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
|
||||||
"correct hex": {
|
|
||||||
reversedHex: "010011AC",
|
|
||||||
IP: []byte{0xac, 0x11, 0x0, 0x1},
|
|
||||||
err: nil},
|
|
||||||
"correct hex 2": {
|
|
||||||
reversedHex: "000011AC",
|
|
||||||
IP: []byte{0xac, 0x11, 0x0, 0x0},
|
|
||||||
err: nil},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
IP, err := reversedHexToIPv4(tc.reversedHex)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.IP, IP)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_hexMaskToDecMask(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
hexString string
|
|
||||||
mask net.IPMask
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"empty hex": {
|
|
||||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
|
||||||
"bad hex": {
|
|
||||||
hexString: "x",
|
|
||||||
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
|
||||||
"3 bytes hex": {
|
|
||||||
hexString: "9abcde",
|
|
||||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
|
||||||
"16": {
|
|
||||||
hexString: "0000FFFF",
|
|
||||||
mask: []byte{0xff, 0xff, 0x0, 0x0},
|
|
||||||
err: nil},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
mask, err := hexToIPv4Mask(tc.hexString)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.mask, mask)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,48 +1,45 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *routing) AddRouteVia(ctx context.Context,
|
func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error {
|
||||||
subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
destinationStr := destination.String()
|
||||||
subnetStr := subnet.String()
|
r.logger.Info("adding route for %s", destinationStr)
|
||||||
r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface)
|
|
||||||
exists, err := r.routeExists(subnet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if r.debug {
|
if r.debug {
|
||||||
fmt.Printf("ip route add %s via %s dev %s\n", subnetStr, defaultGateway, defaultInterface)
|
fmt.Printf("ip route add %s via %s dev %s\n", destinationStr, gateway, iface)
|
||||||
}
|
}
|
||||||
output, err := r.commander.Run(ctx,
|
|
||||||
"ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
link, err := netlink.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w",
|
return fmt.Errorf("cannot add route for %s: %w", destinationStr, err)
|
||||||
subnetStr, defaultGateway, "dev", defaultInterface, output, err)
|
}
|
||||||
|
route := netlink.Route{
|
||||||
|
Dst: &destination,
|
||||||
|
Gw: gateway,
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
}
|
||||||
|
if err := netlink.RouteReplace(&route); err != nil {
|
||||||
|
return fmt.Errorf("cannot add route for %s: %w", destinationStr, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) {
|
func (r *routing) DeleteRouteVia(destination net.IPNet) (err error) {
|
||||||
subnetStr := subnet.String()
|
destinationStr := destination.String()
|
||||||
r.logger.Info("deleting route for %s", subnetStr)
|
r.logger.Info("deleting route for %s", destinationStr)
|
||||||
exists, err := r.routeExists(subnet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if !exists { // thanks to @npawelek https://github.com/npawelek
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if r.debug {
|
if r.debug {
|
||||||
fmt.Printf("ip route del %s\n", subnetStr)
|
fmt.Printf("ip route del %s\n", destinationStr)
|
||||||
}
|
}
|
||||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr)
|
route := netlink.Route{
|
||||||
if err != nil {
|
Dst: &destination,
|
||||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err)
|
}
|
||||||
|
if err := netlink.RouteDel(&route); err != nil {
|
||||||
|
return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
"github.com/qdm12/golibs/command/mock_command"
|
|
||||||
"github.com/qdm12/golibs/files/mock_files"
|
|
||||||
"github.com/qdm12/golibs/logging/mock_logging"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_DeleteRouteVia(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ctx := context.Background()
|
|
||||||
tests := map[string]struct {
|
|
||||||
subnet net.IPNet
|
|
||||||
runOutput string
|
|
||||||
runErr error
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"no output no error": {
|
|
||||||
subnet: net.IPNet{
|
|
||||||
IP: net.IP{192, 168, 2, 0},
|
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"error only": {
|
|
||||||
subnet: net.IPNet{
|
|
||||||
IP: net.IP{192, 168, 2, 0},
|
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
runErr: fmt.Errorf("error"),
|
|
||||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"),
|
|
||||||
},
|
|
||||||
"error and output": {
|
|
||||||
subnet: net.IPNet{
|
|
||||||
IP: net.IP{192, 168, 2, 0},
|
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
runErr: fmt.Errorf("error"),
|
|
||||||
runOutput: "output",
|
|
||||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
|
|
||||||
subnetStr := tc.subnet.String()
|
|
||||||
|
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
||||||
logger.EXPECT().Info("deleting route for %s")
|
|
||||||
commander := mock_command.NewMockCommander(mockCtrl)
|
|
||||||
commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr).
|
|
||||||
Return(tc.runOutput, tc.runErr).Times(1)
|
|
||||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
|
||||||
//nolint:lll
|
|
||||||
routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`)
|
|
||||||
fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil)
|
|
||||||
r := &routing{
|
|
||||||
logger: logger,
|
|
||||||
commander: commander,
|
|
||||||
fileManager: fileManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := r.DeleteRouteVia(ctx, tc.subnet)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -4,120 +4,105 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
|
|
||||||
lines := strings.Split(strings.TrimSuffix(string(data), "\n"), "\n")
|
|
||||||
lines = lines[1:]
|
|
||||||
entries = make([]routingEntry, len(lines))
|
|
||||||
for i := range lines {
|
|
||||||
entries[i], err = parseRoutingEntry(lines[i])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("line %d in %s: %w", i+1, constants.NetRoute, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return entries, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) {
|
|
||||||
data, err := fileManager.ReadFile(string(constants.NetRoute))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return parseRoutingTable(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
||||||
entries, err := getRoutingEntries(r.fileManager)
|
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
const minEntries = 2
|
for _, route := range routes {
|
||||||
if len(entries) < minEntries {
|
if route.Dst == nil {
|
||||||
return "", nil, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
defaultGateway = route.Gw
|
||||||
|
linkIndex := route.LinkIndex
|
||||||
|
link, err := netlink.LinkByIndex(linkIndex)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("cannot obtain link with index %d for default route: %w", linkIndex, err)
|
||||||
}
|
}
|
||||||
var defaultRouteEntry routingEntry
|
attributes := link.Attrs()
|
||||||
for _, entry := range entries {
|
defaultInterface = attributes.Name
|
||||||
if entry.mask.String() == "00000000" {
|
|
||||||
defaultRouteEntry = entry
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if defaultRouteEntry.iface == "" {
|
|
||||||
return "", nil, fmt.Errorf("cannot find default route")
|
|
||||||
}
|
|
||||||
defaultInterface = defaultRouteEntry.iface
|
|
||||||
defaultGateway = defaultRouteEntry.gateway
|
|
||||||
r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String())
|
r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String())
|
||||||
return defaultInterface, defaultGateway, nil
|
return defaultInterface, defaultGateway, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", nil, fmt.Errorf("cannot find default route in %d routes", len(routes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||||
entries, err := getRoutingEntries(r.fileManager)
|
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return defaultSubnet, err
|
return defaultSubnet, fmt.Errorf("cannot find local subnet: %w", err)
|
||||||
}
|
}
|
||||||
const minEntries = 2
|
|
||||||
if len(entries) < minEntries {
|
defaultLinkIndex := -1
|
||||||
return defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
for _, route := range routes {
|
||||||
}
|
if route.Dst == nil {
|
||||||
var localSubnetEntry routingEntry
|
defaultLinkIndex = route.LinkIndex
|
||||||
for _, entry := range entries {
|
|
||||||
if entry.gateway.Equal(net.IP{0, 0, 0, 0}) && !strings.HasPrefix(entry.iface, "tun") {
|
|
||||||
localSubnetEntry = entry
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if localSubnetEntry.iface == "" {
|
if defaultLinkIndex == -1 {
|
||||||
return defaultSubnet, fmt.Errorf("cannot find local subnet route")
|
return defaultSubnet, fmt.Errorf("cannot find local subnet: cannot find default link")
|
||||||
}
|
}
|
||||||
defaultSubnet = net.IPNet{IP: localSubnetEntry.destination, Mask: localSubnetEntry.mask}
|
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.Gw != nil || route.LinkIndex != defaultLinkIndex {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defaultSubnet = *route.Dst
|
||||||
r.logger.Info("local subnet found: %s", defaultSubnet.String())
|
r.logger.Info("local subnet found: %s", defaultSubnet.String())
|
||||||
return defaultSubnet, nil
|
return defaultSubnet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
|
func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
|
||||||
entries, err := getRoutingEntries(r.fileManager)
|
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("cannot check route existence: %w", err)
|
return nil, fmt.Errorf("cannot find VPN destination IP: %w", err)
|
||||||
}
|
}
|
||||||
for _, entry := range entries {
|
|
||||||
entrySubnet := net.IPNet{IP: entry.destination, Mask: entry.mask}
|
|
||||||
if entrySubnet.String() == subnet.String() {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) {
|
defaultLinkIndex := -1
|
||||||
entries, err := getRoutingEntries(r.fileManager)
|
for _, route := range routes {
|
||||||
if err != nil {
|
if route.Dst == nil {
|
||||||
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
|
defaultLinkIndex = route.LinkIndex
|
||||||
}
|
break
|
||||||
for _, entry := range entries {
|
|
||||||
if entry.iface == defaultInterface &&
|
|
||||||
!ipIsPrivate(entry.destination) &&
|
|
||||||
bytes.Equal(entry.mask, net.IPMask{255, 255, 255, 255}) {
|
|
||||||
return entry.destination, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes")
|
if defaultLinkIndex == -1 {
|
||||||
|
return nil, fmt.Errorf("cannot find VPN destination IP: cannot find default link")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.LinkIndex == defaultLinkIndex &&
|
||||||
|
route.Dst != nil &&
|
||||||
|
!ipIsPrivate(route.Dst.IP) &&
|
||||||
|
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
|
||||||
|
return route.Dst.IP, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("cannot find VPN destination IP address from ip routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
|
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
|
||||||
entries, err := getRoutingEntries(r.fileManager)
|
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err)
|
return nil, fmt.Errorf("cannot find VPN local gateway IP: %w", err)
|
||||||
}
|
}
|
||||||
for _, entry := range entries {
|
for _, route := range routes {
|
||||||
if entry.iface == string(constants.TUN) &&
|
link, err := netlink.LinkByIndex(route.LinkIndex)
|
||||||
entry.destination.Equal(net.IP{0, 0, 0, 0}) {
|
if err != nil {
|
||||||
return entry.gateway, nil
|
return nil, fmt.Errorf("cannot find VPN local gateway IP: %w", err)
|
||||||
|
}
|
||||||
|
interfaceName := link.Attrs().Name
|
||||||
|
if interfaceName == string(constants.TUN) &&
|
||||||
|
route.Dst != nil &&
|
||||||
|
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
|
||||||
|
return route.Gw, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes")
|
return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes")
|
||||||
|
|||||||
@@ -1,352 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
"github.com/qdm12/golibs/files/mock_files"
|
|
||||||
"github.com/qdm12/golibs/logging/mock_logging"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
tun0 00000000 050A030A 0003 0 0 0 00000080 0 0 0
|
|
||||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
|
||||||
tun0 010A030A 050A030A 0007 0 0 0 FFFFFFFF 0 0 0
|
|
||||||
tun0 050A030A 00000000 0005 0 0 0 FFFFFFFF 0 0 0
|
|
||||||
eth0 42196956 010011AC 0007 0 0 0 FFFFFFFF 0 0 0
|
|
||||||
tun0 00000080 050A030A 0003 0 0 0 00000080 0 0 0
|
|
||||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
|
||||||
`
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
func Test_parseRoutingTable(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
data []byte
|
|
||||||
entries []routingEntry
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"nil data": {
|
|
||||||
entries: []routingEntry{},
|
|
||||||
},
|
|
||||||
"legend only": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
`),
|
|
||||||
entries: []routingEntry{},
|
|
||||||
},
|
|
||||||
"legend and single line": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`),
|
|
||||||
entries: []routingEntry{{
|
|
||||||
iface: "eth0",
|
|
||||||
destination: net.IP{192, 168, 2, 0},
|
|
||||||
gateway: net.IP{10, 0, 0, 1},
|
|
||||||
flags: "0003",
|
|
||||||
mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
"legend and two lines": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
eth0 0002A8C0 0100000A 0002 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`),
|
|
||||||
entries: []routingEntry{
|
|
||||||
{
|
|
||||||
iface: "eth0",
|
|
||||||
destination: net.IP{192, 168, 2, 0},
|
|
||||||
gateway: net.IP{10, 0, 0, 1},
|
|
||||||
flags: "0003",
|
|
||||||
mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
iface: "eth0",
|
|
||||||
destination: net.IP{192, 168, 2, 0},
|
|
||||||
gateway: net.IP{10, 0, 0, 1},
|
|
||||||
flags: "0002",
|
|
||||||
mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
"error": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`),
|
|
||||||
entries: nil,
|
|
||||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
entries, err := parseRoutingTable(tc.data)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.entries, entries)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
func Test_DefaultRoute(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
data []byte
|
|
||||||
readErr error
|
|
||||||
defaultInterface string
|
|
||||||
defaultGateway net.IP
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"no data": {
|
|
||||||
err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)},
|
|
||||||
"read error": {
|
|
||||||
readErr: fmt.Errorf("error"),
|
|
||||||
err: fmt.Errorf("error")},
|
|
||||||
"parse error": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 x
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
|
||||||
"single entry": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
|
||||||
"success": {
|
|
||||||
data: []byte(exampleRouteData),
|
|
||||||
defaultInterface: "eth0",
|
|
||||||
defaultGateway: net.IP{172, 17, 0, 1},
|
|
||||||
},
|
|
||||||
"not found": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 00000000 010011AC 0003 0 0 0 10000000 0 0 0
|
|
||||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("cannot find default route"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
||||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
|
||||||
|
|
||||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
|
||||||
Return(tc.data, tc.readErr).Times(1)
|
|
||||||
if tc.err == nil {
|
|
||||||
logger.EXPECT().Info(
|
|
||||||
"default route found: interface %s, gateway %s",
|
|
||||||
tc.defaultInterface, tc.defaultGateway.String(),
|
|
||||||
).Times(1)
|
|
||||||
}
|
|
||||||
r := &routing{logger: logger, fileManager: filemanager}
|
|
||||||
defaultInterface, defaultGateway, err := r.DefaultRoute()
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.defaultInterface, defaultInterface)
|
|
||||||
assert.Equal(t, tc.defaultGateway, defaultGateway)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
func Test_LocalSubnet(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
data []byte
|
|
||||||
readErr error
|
|
||||||
localSubnet net.IPNet
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"no data": {
|
|
||||||
err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)},
|
|
||||||
"read error": {
|
|
||||||
readErr: fmt.Errorf("error"),
|
|
||||||
err: fmt.Errorf("error")},
|
|
||||||
"parse error": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 x
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
|
||||||
"single entry": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
|
||||||
"success": {
|
|
||||||
data: []byte(exampleRouteData),
|
|
||||||
localSubnet: net.IPNet{
|
|
||||||
IP: net.IP{172, 17, 0, 0},
|
|
||||||
Mask: net.IPMask{255, 255, 0, 0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"not found": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
|
||||||
eth0 000011AC 10000000 0001 0 0 0 0000FFFF 0 0 0
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("cannot find local subnet route"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
||||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
|
||||||
|
|
||||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
|
||||||
Return(tc.data, tc.readErr).Times(1)
|
|
||||||
if tc.err == nil {
|
|
||||||
logger.EXPECT().Info("local subnet found: %s", tc.localSubnet.String()).Times(1)
|
|
||||||
}
|
|
||||||
r := &routing{logger: logger, fileManager: filemanager}
|
|
||||||
localSubnet, err := r.LocalSubnet()
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.localSubnet, localSubnet)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
func Test_routeExists(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
subnet net.IPNet
|
|
||||||
data []byte
|
|
||||||
readErr error
|
|
||||||
exists bool
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"no data": {},
|
|
||||||
"read error": {
|
|
||||||
readErr: fmt.Errorf("error"),
|
|
||||||
err: fmt.Errorf("cannot check route existence: error"),
|
|
||||||
},
|
|
||||||
"parse error": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 x
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("cannot check route existence: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"),
|
|
||||||
},
|
|
||||||
"not existing": {
|
|
||||||
subnet: net.IPNet{
|
|
||||||
IP: net.IP{192, 168, 2, 0},
|
|
||||||
Mask: net.IPMask{255, 255, 255, 128},
|
|
||||||
},
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
"existing": {
|
|
||||||
subnet: net.IPNet{
|
|
||||||
IP: net.IP{192, 168, 2, 0},
|
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`),
|
|
||||||
exists: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
|
||||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
|
||||||
Return(tc.data, tc.readErr).Times(1)
|
|
||||||
r := &routing{fileManager: filemanager}
|
|
||||||
exists, err := r.routeExists(tc.subnet)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.exists, exists)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:lll
|
|
||||||
func Test_VPNDestinationIP(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
defaultInterface string
|
|
||||||
data []byte
|
|
||||||
readErr error
|
|
||||||
ip net.IP
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"no data": {
|
|
||||||
err: fmt.Errorf("cannot find VPN gateway IP address from ip routes"),
|
|
||||||
},
|
|
||||||
"read error": {
|
|
||||||
readErr: fmt.Errorf("error"),
|
|
||||||
err: fmt.Errorf("cannot find VPN gateway IP address: error"),
|
|
||||||
},
|
|
||||||
"parse error": {
|
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
eth0 x
|
|
||||||
`),
|
|
||||||
err: fmt.Errorf("cannot find VPN gateway IP address: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"),
|
|
||||||
},
|
|
||||||
"found eth0": {
|
|
||||||
defaultInterface: "eth0",
|
|
||||||
data: []byte(exampleRouteData),
|
|
||||||
ip: net.IP{86, 105, 25, 66},
|
|
||||||
},
|
|
||||||
"not found tun0": {
|
|
||||||
defaultInterface: "tun0",
|
|
||||||
data: []byte(exampleRouteData),
|
|
||||||
err: fmt.Errorf("cannot find VPN gateway IP address from ip routes"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
|
||||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
|
||||||
Return(tc.data, tc.readErr).Times(1)
|
|
||||||
r := &routing{fileManager: filemanager}
|
|
||||||
ip, err := r.VPNDestinationIP(tc.defaultInterface)
|
|
||||||
if tc.err != nil {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.ip, ip)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,37 +1,30 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/command"
|
|
||||||
"github.com/qdm12/golibs/files"
|
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Routing interface {
|
type Routing interface {
|
||||||
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error
|
||||||
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
DeleteRouteVia(destination net.IPNet) (err error)
|
||||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
||||||
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
||||||
VPNDestinationIP(defaultInterface string) (ip net.IP, err error)
|
VPNDestinationIP() (ip net.IP, err error)
|
||||||
VPNLocalGatewayIP() (ip net.IP, err error)
|
VPNLocalGatewayIP() (ip net.IP, err error)
|
||||||
SetDebug()
|
SetDebug()
|
||||||
}
|
}
|
||||||
|
|
||||||
type routing struct {
|
type routing struct {
|
||||||
commander command.Commander
|
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
fileManager files.FileManager
|
|
||||||
debug bool
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfigurator creates a new Configurator instance.
|
// NewConfigurator creates a new Configurator instance.
|
||||||
func NewRouting(logger logging.Logger, fileManager files.FileManager) Routing {
|
func NewRouting(logger logging.Logger) Routing {
|
||||||
return &routing{
|
return &routing{
|
||||||
commander: command.NewCommander(),
|
|
||||||
logger: logger.WithPrefix("routing: "),
|
logger: logger.WithPrefix("routing: "),
|
||||||
fileManager: fileManager,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user