Maint: internal/routing IP rules functions

- Take in `src` as `*net.IPNet` instead of `net.IP`
- Take `dst` IP network
- Debug logged `ip rule` dynamically built
- Add unit tests for all IP rules functions
This commit is contained in:
Quentin McGaw (desktop)
2021-08-26 13:59:43 +00:00
parent 6c2a3e36b5
commit 2901db3cf3
5 changed files with 612 additions and 16 deletions

View File

@@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"net"
"github.com/qdm12/gluetun/internal/netlink"
)
const (
@@ -53,8 +55,11 @@ func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
return fmt.Errorf("%w: %s", errDefaultIP, err)
}
if err := r.addIPRule(defaultIP, table, inboundPriority); err != nil {
return fmt.Errorf("%w: %s", errIPRuleAdd, err)
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
ruleDstNet := (*net.IPNet)(nil)
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("%w: %s", errRuleAdd, err)
}
return nil
@@ -66,8 +71,11 @@ func (r *Routing) delRuleInboundFromDefault(table int) (err error) {
return fmt.Errorf("%w: %s", errDefaultIP, err)
}
if err := r.deleteIPRule(defaultIP, table, inboundPriority); err != nil {
return fmt.Errorf("%w: %s", errIPRuleAdd, err)
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
ruleDstNet := (*net.IPNet)(nil)
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("%w: %s", errRuleDelete, err)
}
return nil

View File

@@ -5,23 +5,21 @@ import (
"errors"
"fmt"
"net"
"strconv"
"github.com/qdm12/gluetun/internal/netlink"
)
var (
errIPRuleAdd = errors.New("cannot add IP rule")
errRulesList = errors.New("cannot list rules")
)
func (r *Routing) addIPRule(src net.IP, table, priority int) error {
r.logger.Debug("ip rule add from " + src.String() +
" lookup " + strconv.Itoa(table) +
" pref " + strconv.Itoa(priority))
func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
const add = true
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = netlink.NewIPNet(src)
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
@@ -45,13 +43,13 @@ func (r *Routing) addIPRule(src net.IP, table, priority int) error {
return nil
}
func (r *Routing) deleteIPRule(src net.IP, table, priority int) error {
r.logger.Debug("ip rule del from " + src.String() +
" lookup " + strconv.Itoa(table) +
" pref " + strconv.Itoa(priority))
func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error {
const add = false
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = netlink.NewIPNet(src)
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
@@ -72,3 +70,32 @@ func (r *Routing) deleteIPRule(src net.IP, table, priority int) error {
}
return nil
}
func ruleDbgMsg(add bool, src, dst *net.IPNet,
table, priority int) (debugMessage string) {
debugMessage = "ip rule"
if add {
debugMessage += " add"
} else {
debugMessage += " del"
}
if src != nil {
debugMessage += " from " + src.String()
}
if dst != nil {
debugMessage += " to " + dst.String()
}
if table != 0 {
debugMessage += " lookup " + fmt.Sprint(table)
}
if priority != -1 {
debugMessage += " pref " + fmt.Sprint(priority)
}
return debugMessage
}

View File

@@ -0,0 +1,308 @@
package routing
import (
"errors"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/netlink/mock_netlink"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func makeIPNet(t *testing.T, n byte) *net.IPNet {
t.Helper()
return &net.IPNet{
IP: net.IPv4(n, n, n, 0),
Mask: net.IPv4Mask(255, 255, 255, 0),
}
}
func makeIPRule(t *testing.T, src, dst *net.IPNet,
table, priority int) *netlink.Rule {
t.Helper()
rule := netlink.NewRule()
rule.Src = src
rule.Dst = dst
rule.Table = table
rule.Priority = priority
return rule
}
func Test_Routing_addIPRule(t *testing.T) {
t.Parallel()
errDummy := errors.New("dummy error")
type ruleListCall struct {
rules []netlink.Rule
err error
}
type ruleAddCall struct {
expected bool
ruleToAdd *netlink.Rule
err error
}
testCases := map[string]struct {
src *net.IPNet
dst *net.IPNet
table int
priority int
dbgMsg string
ruleList ruleListCall
ruleAdd ruleAddCall
err error
}{
"list error": {
dbgMsg: "ip rule add pref 0",
ruleList: ruleListCall{
err: errDummy,
},
err: errors.New("cannot list rules: dummy error"),
},
"rule already exists": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 99,
priority: 99,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
},
},
},
"add rule error": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 99,
priority: 99,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleAdd: ruleAddCall{
expected: true,
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy,
},
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"),
},
"add rule success": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 99,
priority: 99,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 101, 101),
},
},
ruleAdd: ruleAddCall{
expected: true,
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
logger := mock_logging.NewMockLogger(ctrl)
logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := mock_netlink.NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleAdd.expected {
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
Return(testCase.ruleAdd.err)
}
r := Routing{
logger: logger,
netLinker: netLinker,
}
err := r.addIPRule(testCase.src, testCase.dst,
testCase.table, testCase.priority)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func Test_Routing_deleteIPRule(t *testing.T) {
t.Parallel()
errDummy := errors.New("dummy error")
type ruleListCall struct {
rules []netlink.Rule
err error
}
type ruleDelCall struct {
expected bool
ruleToDel *netlink.Rule
err error
}
testCases := map[string]struct {
src *net.IPNet
dst *net.IPNet
table int
priority int
dbgMsg string
ruleList ruleListCall
ruleDel ruleDelCall
err error
}{
"list error": {
dbgMsg: "ip rule del pref 0",
ruleList: ruleListCall{
err: errDummy,
},
err: errors.New("cannot list rules: dummy error"),
},
"rule delete error": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 99,
priority: 99,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
},
},
ruleDel: ruleDelCall{
expected: true,
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy,
},
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"),
},
"rule deleted": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 99,
priority: 99,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
},
},
ruleDel: ruleDelCall{
expected: true,
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
},
},
"rule does not exist": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 99,
priority: 99,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(t, makeIPNet(t, 2), makeIPNet(t, 2), 99, 99),
*makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 101, 101),
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
logger := mock_logging.NewMockLogger(ctrl)
logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := mock_netlink.NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleDel.expected {
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
Return(testCase.ruleDel.err)
}
r := Routing{
logger: logger,
netLinker: netLinker,
}
err := r.deleteIPRule(testCase.src, testCase.dst,
testCase.table, testCase.priority)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func Test_ruleDbgMsg(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
add bool
src *net.IPNet
dst *net.IPNet
table int
priority int
dbgMsg string
}{
"default values": {
dbgMsg: "ip rule del pref 0",
},
"add rule": {
add: true,
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 100,
priority: 101,
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
"del rule": {
src: makeIPNet(t, 1),
dst: makeIPNet(t, 2),
table: 100,
priority: 101,
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 100 pref 101",
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
dbgMsg := ruleDbgMsg(testCase.add, testCase.src,
testCase.dst, testCase.table, testCase.priority)
assert.Equal(t, testCase.dbgMsg, dbgMsg)
})
}
}