diff --git a/internal/routing/rules.go b/internal/routing/rules.go index 763ffa20..b74b0e56 100644 --- a/internal/routing/rules.go +++ b/internal/routing/rules.go @@ -23,18 +23,15 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error { rule.Priority = priority rule.Table = table - rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) + existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) if err != nil { return fmt.Errorf("%w: %s", errRulesList, err) } - for _, existingRule := range rules { - if existingRule.Src != nil && - existingRule.Src.IP.Equal(rule.Src.IP) && - bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && - existingRule.Priority == rule.Priority && - existingRule.Table == rule.Table { - return nil // already exists + for i := range existingRules { + if !rulesAreEqual(&existingRules[i], rule) { + continue } + return nil // already exists } if err := r.netLinker.RuleAdd(rule); err != nil { @@ -53,19 +50,16 @@ func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error { rule.Priority = priority rule.Table = table - rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) + existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) if err != nil { return fmt.Errorf("%w: %s", errRulesList, err) } - for _, existingRule := range rules { - if existingRule.Src != nil && - existingRule.Src.IP.Equal(rule.Src.IP) && - bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && - existingRule.Priority == rule.Priority && - existingRule.Table == rule.Table { - if err := r.netLinker.RuleDel(rule); err != nil { - return fmt.Errorf("%w: for rule: %s", err, rule) - } + for i := range existingRules { + if !rulesAreEqual(&existingRules[i], rule) { + continue + } + if err := r.netLinker.RuleDel(rule); err != nil { + return fmt.Errorf("%w: for rule: %s", err, rule) } } return nil @@ -99,3 +93,27 @@ func ruleDbgMsg(add bool, src, dst *net.IPNet, return debugMessage } + +func rulesAreEqual(a, b *netlink.Rule) bool { + // fmt.Println(a, b) + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return ipNetsAreEqual(a.Src, b.Src) && + ipNetsAreEqual(a.Dst, b.Dst) && + a.Priority == b.Priority && + a.Table == b.Table +} + +func ipNetsAreEqual(a, b *net.IPNet) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask) +} diff --git a/internal/routing/rules_test.go b/internal/routing/rules_test.go index 7a951803..8ddfb8a4 100644 --- a/internal/routing/rules_test.go +++ b/internal/routing/rules_test.go @@ -306,3 +306,125 @@ func Test_ruleDbgMsg(t *testing.T) { }) } } + +func Test_rulesAreEqual(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + a *netlink.Rule + b *netlink.Rule + equal bool + }{ + "both nil": { + equal: true, + }, + "first nil": { + b: &netlink.Rule{}, + }, + "second nil": { + a: &netlink.Rule{}, + }, + "both not nil": { + a: &netlink.Rule{}, + b: &netlink.Rule{}, + equal: true, + }, + "both equal": { + a: &netlink.Rule{ + Src: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + Priority: 100, + Table: 101, + }, + b: &netlink.Rule{ + Src: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + Priority: 100, + Table: 101, + }, + equal: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + equal := rulesAreEqual(testCase.a, testCase.b) + + assert.Equal(t, testCase.equal, equal) + }) + } +} + +func Test_ipNetsAreEqual(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + a *net.IPNet + b *net.IPNet + equal bool + }{ + "both nil": { + equal: true, + }, + "first nil": { + b: &net.IPNet{}, + }, + "second nil": { + a: &net.IPNet{}, + }, + "both not nil": { + a: &net.IPNet{}, + b: &net.IPNet{}, + equal: true, + }, + "both equal": { + a: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + b: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + equal: true, + }, + "both not equal by IP": { + a: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + b: &net.IPNet{ + IP: net.IPv4(2, 2, 2, 2), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, + "both not equal by mask": { + a: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + b: &net.IPNet{ + IP: net.IPv4(1, 1, 1, 1), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + equal := ipNetsAreEqual(testCase.a, testCase.b) + + assert.Equal(t, testCase.equal, equal) + }) + } +}