Maint: fix rules equality check for nil networks
This commit is contained in:
@@ -23,18 +23,15 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
|
|||||||
rule.Priority = priority
|
rule.Priority = priority
|
||||||
rule.Table = table
|
rule.Table = table
|
||||||
|
|
||||||
rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRulesList, err)
|
return fmt.Errorf("%w: %s", errRulesList, err)
|
||||||
}
|
}
|
||||||
for _, existingRule := range rules {
|
for i := range existingRules {
|
||||||
if existingRule.Src != nil &&
|
if !rulesAreEqual(&existingRules[i], rule) {
|
||||||
existingRule.Src.IP.Equal(rule.Src.IP) &&
|
continue
|
||||||
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
|
|
||||||
existingRule.Priority == rule.Priority &&
|
|
||||||
existingRule.Table == rule.Table {
|
|
||||||
return nil // already exists
|
|
||||||
}
|
}
|
||||||
|
return nil // already exists
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.netLinker.RuleAdd(rule); err != nil {
|
if err := r.netLinker.RuleAdd(rule); err != nil {
|
||||||
@@ -53,21 +50,18 @@ func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error {
|
|||||||
rule.Priority = priority
|
rule.Priority = priority
|
||||||
rule.Table = table
|
rule.Table = table
|
||||||
|
|
||||||
rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRulesList, err)
|
return fmt.Errorf("%w: %s", errRulesList, err)
|
||||||
}
|
}
|
||||||
for _, existingRule := range rules {
|
for i := range existingRules {
|
||||||
if existingRule.Src != nil &&
|
if !rulesAreEqual(&existingRules[i], rule) {
|
||||||
existingRule.Src.IP.Equal(rule.Src.IP) &&
|
continue
|
||||||
bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) &&
|
}
|
||||||
existingRule.Priority == rule.Priority &&
|
|
||||||
existingRule.Table == rule.Table {
|
|
||||||
if err := r.netLinker.RuleDel(rule); err != nil {
|
if err := r.netLinker.RuleDel(rule); err != nil {
|
||||||
return fmt.Errorf("%w: for rule: %s", err, rule)
|
return fmt.Errorf("%w: for rule: %s", err, rule)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,3 +93,27 @@ func ruleDbgMsg(add bool, src, dst *net.IPNet,
|
|||||||
|
|
||||||
return debugMessage
|
return debugMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rulesAreEqual(a, b *netlink.Rule) bool {
|
||||||
|
// fmt.Println(a, b)
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ipNetsAreEqual(a.Src, b.Src) &&
|
||||||
|
ipNetsAreEqual(a.Dst, b.Dst) &&
|
||||||
|
a.Priority == b.Priority &&
|
||||||
|
a.Table == b.Table
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipNetsAreEqual(a, b *net.IPNet) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
|
||||||
|
}
|
||||||
|
|||||||
@@ -306,3 +306,125 @@ func Test_ruleDbgMsg(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_rulesAreEqual(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
a *netlink.Rule
|
||||||
|
b *netlink.Rule
|
||||||
|
equal bool
|
||||||
|
}{
|
||||||
|
"both nil": {
|
||||||
|
equal: true,
|
||||||
|
},
|
||||||
|
"first nil": {
|
||||||
|
b: &netlink.Rule{},
|
||||||
|
},
|
||||||
|
"second nil": {
|
||||||
|
a: &netlink.Rule{},
|
||||||
|
},
|
||||||
|
"both not nil": {
|
||||||
|
a: &netlink.Rule{},
|
||||||
|
b: &netlink.Rule{},
|
||||||
|
equal: true,
|
||||||
|
},
|
||||||
|
"both equal": {
|
||||||
|
a: &netlink.Rule{
|
||||||
|
Src: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
b: &netlink.Rule{
|
||||||
|
Src: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
equal: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
equal := rulesAreEqual(testCase.a, testCase.b)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.equal, equal)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_ipNetsAreEqual(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
a *net.IPNet
|
||||||
|
b *net.IPNet
|
||||||
|
equal bool
|
||||||
|
}{
|
||||||
|
"both nil": {
|
||||||
|
equal: true,
|
||||||
|
},
|
||||||
|
"first nil": {
|
||||||
|
b: &net.IPNet{},
|
||||||
|
},
|
||||||
|
"second nil": {
|
||||||
|
a: &net.IPNet{},
|
||||||
|
},
|
||||||
|
"both not nil": {
|
||||||
|
a: &net.IPNet{},
|
||||||
|
b: &net.IPNet{},
|
||||||
|
equal: true,
|
||||||
|
},
|
||||||
|
"both equal": {
|
||||||
|
a: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
b: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
equal: true,
|
||||||
|
},
|
||||||
|
"both not equal by IP": {
|
||||||
|
a: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
b: &net.IPNet{
|
||||||
|
IP: net.IPv4(2, 2, 2, 2),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"both not equal by mask": {
|
||||||
|
a: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||||
|
},
|
||||||
|
b: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 1, 1, 1),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
equal := ipNetsAreEqual(testCase.a, testCase.b)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.equal, equal)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user