fix(firewall): delete chain rules by line number (#2411)
- Fix #2334 - Parsing of iptables chains, contributing to progress for #1856
This commit is contained in:
98
internal/firewall/delete.go
Normal file
98
internal/firewall/delete.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isDeleteMatchInstruction returns true if the iptables instruction
|
||||
// is a delete instruction by rule matching. It returns false if the
|
||||
// instruction is a delete instruction by line number, or not a delete
|
||||
// instruction.
|
||||
func isDeleteMatchInstruction(instruction string) bool {
|
||||
fields := strings.Fields(instruction)
|
||||
for i, field := range fields {
|
||||
switch {
|
||||
case field != "-D" && field != "--delete": //nolint:goconst
|
||||
continue
|
||||
case i == len(fields)-1: // malformed: missing chain name
|
||||
return false
|
||||
case i == len(fields)-2: // chain name is last field
|
||||
return true
|
||||
default:
|
||||
// chain name is fields[i+1]
|
||||
const base, bitLength = 10, 16
|
||||
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
|
||||
return err != nil // not a line number
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
|
||||
runner Runner, logger Logger) (err error) {
|
||||
targetRule, err := parseIptablesInstruction(instruction)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing iptables command: %w", err)
|
||||
}
|
||||
|
||||
lineNumber, err := findLineNumber(ctx, iptablesBinary,
|
||||
targetRule, runner, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding iptables chain rule line number: %w", err)
|
||||
} else if lineNumber == 0 {
|
||||
logger.Debug("rule matching \"" + instruction + "\" not found")
|
||||
return nil
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
|
||||
instruction, lineNumber))
|
||||
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
|
||||
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findLineNumber finds the line number of an iptables rule.
|
||||
// It returns 0 if the rule is not found.
|
||||
func findLineNumber(ctx context.Context, iptablesBinary string,
|
||||
instruction iptablesInstruction, runner Runner, logger Logger) (
|
||||
lineNumber uint16, err error) {
|
||||
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
|
||||
"--line-numbers", "-n", "-v"}
|
||||
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
|
||||
logger.Debug(cmd.String())
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("command failed: %q: %w", cmd, err)
|
||||
if output != "" {
|
||||
err = fmt.Errorf("%w: %s", err, output)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
chain, err := parseChain(output)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing chain list: %w", err)
|
||||
}
|
||||
|
||||
for _, rule := range chain.rules {
|
||||
if instruction.equalToRule(instruction.table, chain.name, rule) {
|
||||
return rule.lineNumber, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
188
internal/firewall/delete_test.go
Normal file
188
internal/firewall/delete_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_isDeleteMatchInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
isDeleteMatch bool
|
||||
}{
|
||||
"not_delete": {
|
||||
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
|
||||
},
|
||||
"malformed_missing_chain_name": {
|
||||
instruction: "-t nat -D",
|
||||
},
|
||||
"delete_chain_name_last_field": {
|
||||
instruction: "-t nat --delete PREROUTING",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_match": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
|
||||
isDeleteMatch: true,
|
||||
},
|
||||
"delete_line_number_last_field": {
|
||||
instruction: "-t nat -D PREROUTING 2",
|
||||
},
|
||||
"delete_line_number": {
|
||||
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
|
||||
},
|
||||
}
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
|
||||
|
||||
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
|
||||
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
|
||||
"^--line-numbers$", "^-n$", "^-v$")
|
||||
}
|
||||
|
||||
func Test_deleteIPTablesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const iptablesBinary = "/sbin/iptables"
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
instruction string
|
||||
makeRunner func(ctrl *gomock.Controller) *MockRunner
|
||||
makeLogger func(ctrl *gomock.Controller) *MockLogger
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"invalid_instruction": {
|
||||
instruction: "invalid",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing iptables command: iptables command is malformed: " +
|
||||
"fields count 1 is not even: \"invalid\"",
|
||||
},
|
||||
"list_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().
|
||||
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: `finding iptables chain rule line number: command failed: ` +
|
||||
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
|
||||
},
|
||||
"rule_not_found": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
|
||||
nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
"rule_found_delete_error": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
errWrapped: errTest,
|
||||
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
|
||||
},
|
||||
"rule_found_delete_success": {
|
||||
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
|
||||
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
|
||||
"num pkts bytes target prot opt in out source destination \n"+
|
||||
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
|
||||
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
|
||||
nil)
|
||||
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
|
||||
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
|
||||
return runner
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
|
||||
logger := NewMockLogger(ctrl)
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
|
||||
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
|
||||
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
|
||||
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
|
||||
return logger
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
instruction := testCase.instruction
|
||||
var runner *MockRunner
|
||||
if testCase.makeRunner != nil {
|
||||
runner = testCase.makeRunner(ctrl)
|
||||
}
|
||||
var logger *MockLogger
|
||||
if testCase.makeLogger != nil {
|
||||
logger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
|
||||
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
|
||||
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
13
internal/firewall/interfaces.go
Normal file
13
internal/firewall/interfaces.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package firewall
|
||||
|
||||
import "github.com/qdm12/golibs/command"
|
||||
|
||||
type Runner interface {
|
||||
Run(cmd command.ExecCmd) (output string, err error)
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
|
||||
c.ip6tablesMutex.Lock() // only one ip6tables command at once
|
||||
defer c.ip6tablesMutex.Unlock()
|
||||
|
||||
c.logger.Debug(c.ip6Tables + " " + instruction)
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ip6Tables, instruction, output, err)
|
||||
@@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
case "ACCEPT", "DROP": //nolint:goconst
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
|
||||
@@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
|
||||
c.logger.Debug(c.ipTables + " " + instruction)
|
||||
if isDeleteMatchInstruction(instruction) {
|
||||
return deleteIPTablesRule(ctx, c.ipTables, instruction,
|
||||
c.runner, c.logger)
|
||||
}
|
||||
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
|
||||
c.logger.Debug(cmd.String())
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
|
||||
c.ipTables, instruction, output, err)
|
||||
@@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
||||
defaultInterface string, connection models.Connection, remove bool) error {
|
||||
protocol := connection.Protocol
|
||||
if protocol == "tcp-client" {
|
||||
protocol = "tcp"
|
||||
protocol = "tcp" //nolint:goconst
|
||||
}
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
|
||||
|
||||
381
internal/firewall/list.go
Normal file
381
internal/firewall/list.go
Normal file
@@ -0,0 +1,381 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type chain struct {
|
||||
name string
|
||||
policy string
|
||||
packets uint64
|
||||
bytes uint64
|
||||
rules []chainRule
|
||||
}
|
||||
|
||||
type chainRule struct {
|
||||
lineNumber uint16 // starts from 1 and cannot be zero.
|
||||
packets uint64
|
||||
bytes uint64
|
||||
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
||||
protocol string // "tcp", "udp" or "" for all protocols.
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
)
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// pkts bytes target prot opt in out source destination
|
||||
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
||||
linesWithComments := strings.Split(iptablesOutput, "\n")
|
||||
|
||||
// Filter out lines starting with a '#' character
|
||||
lines := make([]string, 0, len(linesWithComments))
|
||||
for _, line := range linesWithComments {
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check for the legend line
|
||||
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
if len(lines) == 0 {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
c.rules = make([]chainRule, len(lines))
|
||||
for i, line := range lines {
|
||||
c.rules[i], err = parseChainRuleLine(line)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// parseChainGeneralDataLine parses the first line of iptables chain list output.
|
||||
// For example, it can parse the following line:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// It returns a chain struct with the parsed data.
|
||||
func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
runesToRemove := []rune{'(', ')', ','}
|
||||
for _, r := range runesToRemove {
|
||||
line = strings.ReplaceAll(line, string(r), "")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
indexToExpectedValue := map[int]string{
|
||||
0: "Chain",
|
||||
2: "policy",
|
||||
5: "packets",
|
||||
7: "bytes",
|
||||
}
|
||||
for index, expectedValue := range indexToExpectedValue {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
base.policy = fields[3]
|
||||
err = checkTarget(base.policy)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
|
||||
}
|
||||
|
||||
packets, err := parseMetricSize(fields[4])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
base.packets = packets
|
||||
|
||||
bytes, err := parseMetricSize(fields[6])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
base.bytes = bytes
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
)
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
err = parseChainRuleField(fieldIndex, field, &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) > minFields {
|
||||
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
numIndex = iota
|
||||
packetsIndex
|
||||
bytesIndex
|
||||
targetIndex
|
||||
protocolIndex
|
||||
optIndex
|
||||
inputInterfaceIndex
|
||||
outputInterfaceIndex
|
||||
sourceIndex
|
||||
destinationIndex
|
||||
)
|
||||
|
||||
switch fieldIndex {
|
||||
case numIndex:
|
||||
rule.lineNumber, err = parseLineNumber(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing line number: %w", err)
|
||||
}
|
||||
case packetsIndex:
|
||||
rule.packets, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
case bytesIndex:
|
||||
rule.bytes, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
case targetIndex:
|
||||
err = checkTarget(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking target: %w", err)
|
||||
}
|
||||
rule.target = field
|
||||
case protocolIndex:
|
||||
rule.protocol, err = parseProtocol(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing protocol: %w", err)
|
||||
}
|
||||
case optIndex: // ignored
|
||||
case inputInterfaceIndex:
|
||||
rule.inputInterface = field
|
||||
case outputInterfaceIndex:
|
||||
rule.outputInterface = field
|
||||
case sourceIndex:
|
||||
rule.source, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case destinationIndex:
|
||||
rule.destination, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i++
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
}
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
case "ports":
|
||||
i++
|
||||
ports, err := parsePortsCSV(optionalFields[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
)
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTargetUnknown = errors.New("unknown target")
|
||||
)
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
)
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0":
|
||||
case "6":
|
||||
protocol = "tcp"
|
||||
case "17":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
)
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
}
|
||||
|
||||
//nolint:gomnd
|
||||
multiplerLetterToValue := map[byte]uint64{
|
||||
'K': 1000,
|
||||
'M': 1000000,
|
||||
'G': 1000000000,
|
||||
'T': 1000000000000,
|
||||
}
|
||||
|
||||
lastCharacter := size[len(size)-1]
|
||||
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
||||
if ok { // multiplier present
|
||||
size = size[:len(size)-1]
|
||||
} else {
|
||||
multiplier = 1
|
||||
}
|
||||
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
}
|
||||
121
internal/firewall/list_test.go
Normal file
121
internal/firewall/list_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseChain(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
iptablesOutput string
|
||||
table chain
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_output": {
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
|
||||
},
|
||||
"single_line_only": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
|
||||
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
|
||||
},
|
||||
"malformed_general_data_line": {
|
||||
iptablesOutput: `Chain INPUT
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
|
||||
"expected 8 fields in \"Chain INPUT\"",
|
||||
},
|
||||
"malformed_legend": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source`,
|
||||
errWrapped: ErrChainListMalformed,
|
||||
errMessage: "iptables chain list output is malformed: legend " +
|
||||
"\"num pkts bytes target prot opt in out source\" " +
|
||||
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
|
||||
},
|
||||
"no_rule": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
},
|
||||
},
|
||||
"some_rules": {
|
||||
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
num pkts bytes target prot opt in out source destination
|
||||
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
|
||||
`,
|
||||
table: chain{
|
||||
name: "INPUT",
|
||||
policy: "ACCEPT",
|
||||
packets: 140000,
|
||||
bytes: 226000000,
|
||||
rules: []chainRule{
|
||||
{
|
||||
lineNumber: 1,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "udp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 2,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "ACCEPT",
|
||||
protocol: "tcp",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
destinationPort: 55405,
|
||||
},
|
||||
{
|
||||
lineNumber: 3,
|
||||
packets: 0,
|
||||
bytes: 0,
|
||||
target: "DROP",
|
||||
protocol: "",
|
||||
inputInterface: "tun0",
|
||||
outputInterface: "*",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
table, err := parseChain(testCase.iptablesOutput)
|
||||
|
||||
assert.Equal(t, testCase.table, table)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,12 +5,6 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
|
||||
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
|
||||
"no default route matching its family", subnet))
|
||||
|
||||
3
internal/firewall/mocks_generate_test.go
Normal file
3
internal/firewall/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package firewall
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger
|
||||
109
internal/firewall/mocks_test.go
Normal file
109
internal/firewall/mocks_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
166
internal/firewall/parse.go
Normal file
166
internal/firewall/parse.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iptablesInstruction struct {
|
||||
table string // defaults to "filter", and can be "nat" for example.
|
||||
append bool
|
||||
chain string // for example INPUT, PREROUTING. Cannot be empty.
|
||||
target string // for example ACCEPT. Can be empty.
|
||||
protocol string // "tcp" or "udp" or "" for all protocols.
|
||||
inputInterface string // for example "tun0" or "" for any interface.
|
||||
outputInterface string // for example "tun0" or "" for any interface.
|
||||
source netip.Prefix // if not valid, then it is unspecified.
|
||||
destination netip.Prefix // if not valid, then it is unspecified.
|
||||
destinationPort uint16 // if zero, there is no destination port
|
||||
toPorts []uint16 // if empty, there is no redirection
|
||||
ctstate []string // if empty, there is no ctstate
|
||||
}
|
||||
|
||||
func (i *iptablesInstruction) setDefaults() {
|
||||
if i.table == "" {
|
||||
i.table = "filter"
|
||||
}
|
||||
}
|
||||
|
||||
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
|
||||
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
|
||||
switch {
|
||||
case i.table != table:
|
||||
return false
|
||||
case i.chain != chain:
|
||||
return false
|
||||
case i.target != rule.target:
|
||||
return false
|
||||
case i.protocol != rule.protocol:
|
||||
return false
|
||||
case i.destinationPort != rule.destinationPort:
|
||||
return false
|
||||
case !slices.Equal(i.toPorts, rule.redirPorts):
|
||||
return false
|
||||
case !slices.Equal(i.ctstate, rule.ctstate):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
|
||||
return false
|
||||
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.source, rule.source):
|
||||
return false
|
||||
case !ipPrefixesEqual(i.destination, rule.destination):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// instruction can be "" which equivalent to the "*" chain rule interface.
|
||||
func networkInterfacesEqual(instruction, chainRule string) bool {
|
||||
return instruction == chainRule || (instruction == "" && chainRule == "*")
|
||||
}
|
||||
|
||||
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
|
||||
return instruction == chainRule ||
|
||||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
|
||||
}
|
||||
|
||||
var (
|
||||
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
|
||||
)
|
||||
|
||||
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
|
||||
if s == "" {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
|
||||
}
|
||||
fields := strings.Fields(s)
|
||||
if len(fields)%2 != 0 {
|
||||
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
|
||||
ErrIptablesCommandMalformed, len(fields), s)
|
||||
}
|
||||
|
||||
for i := 0; i < len(fields); i += 2 {
|
||||
key := fields[i]
|
||||
value := fields[i+1]
|
||||
err = parseInstructionFlag(key, value, &instruction)
|
||||
if err != nil {
|
||||
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
|
||||
}
|
||||
}
|
||||
|
||||
instruction.setDefaults()
|
||||
return instruction, nil
|
||||
}
|
||||
|
||||
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
|
||||
switch key {
|
||||
case "-t", "--table":
|
||||
instruction.table = value
|
||||
case "-D", "--delete":
|
||||
instruction.append = false
|
||||
instruction.chain = value
|
||||
case "-A", "--append":
|
||||
instruction.append = true
|
||||
instruction.chain = value
|
||||
case "-j", "--jump":
|
||||
instruction.target = value
|
||||
case "-p", "--protocol":
|
||||
instruction.protocol = value
|
||||
case "-m", "--match": // ignore match
|
||||
case "-i", "--in-interface":
|
||||
instruction.inputInterface = value
|
||||
case "-o", "--out-interface":
|
||||
instruction.outputInterface = value
|
||||
case "-s", "--source":
|
||||
instruction.source, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case "-d", "--destination":
|
||||
instruction.destination, err = parseIPPrefix(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
case "--dport":
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port: %w", err)
|
||||
}
|
||||
instruction.destinationPort = uint16(destinationPort)
|
||||
case "--ctstate":
|
||||
instruction.ctstate = strings.Split(value, ",")
|
||||
case "--to-ports":
|
||||
portStrings := strings.Split(value, ",")
|
||||
instruction.toPorts = make([]uint16, len(portStrings))
|
||||
for i, portString := range portStrings {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(portString, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing port redirection: %w", err)
|
||||
}
|
||||
instruction.toPorts[i] = uint16(port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
|
||||
slashIndex := strings.Index(value, "/")
|
||||
if slashIndex >= 0 {
|
||||
return netip.ParsePrefix(value)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
|
||||
}
|
||||
return netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
138
internal/firewall/parse_test.go
Normal file
138
internal/firewall/parse_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseIptablesInstruction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
instruction iptablesInstruction
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no_instruction": {
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: empty instruction",
|
||||
},
|
||||
"uneven_fields": {
|
||||
s: "-A",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
|
||||
},
|
||||
"unknown_key": {
|
||||
s: "-x something",
|
||||
errWrapped: ErrIptablesCommandMalformed,
|
||||
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
|
||||
},
|
||||
"one_pair": {
|
||||
s: "-A INPUT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
},
|
||||
},
|
||||
"instruction_A": {
|
||||
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
|
||||
instruction: iptablesInstruction{
|
||||
table: "filter",
|
||||
chain: "INPUT",
|
||||
append: true,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
source: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
destination: netip.MustParsePrefix("5.6.7.8/32"),
|
||||
destinationPort: 10000,
|
||||
target: "ACCEPT",
|
||||
},
|
||||
},
|
||||
"nat_redirection": {
|
||||
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
|
||||
instruction: iptablesInstruction{
|
||||
table: "nat",
|
||||
chain: "PREROUTING",
|
||||
append: false,
|
||||
inputInterface: "tun0",
|
||||
protocol: "tcp",
|
||||
destinationPort: 43716,
|
||||
target: "REDIRECT",
|
||||
toPorts: []uint16{5678},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rule, err := parseIptablesInstruction(testCase.s)
|
||||
|
||||
assert.Equal(t, testCase.instruction, rule)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseIPPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
value string
|
||||
prefix netip.Prefix
|
||||
errMessage string
|
||||
}{
|
||||
"empty": {
|
||||
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
|
||||
},
|
||||
"invalid": {
|
||||
value: "invalid",
|
||||
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
|
||||
},
|
||||
"valid_ipv4_with_bits": {
|
||||
value: "10.0.0.0/16",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
|
||||
},
|
||||
"valid_ipv4_without_bits": {
|
||||
value: "10.0.0.4",
|
||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
|
||||
},
|
||||
"valid_ipv6_with_bits": {
|
||||
value: "2001:db8::/32",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
32),
|
||||
},
|
||||
"valid_ipv6_without_bits": {
|
||||
value: "2001:db8::",
|
||||
prefix: netip.PrefixFrom(
|
||||
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
|
||||
128),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prefix, err := parseIPPrefix(testCase.value)
|
||||
|
||||
assert.Equal(t, testCase.prefix, prefix)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
|
||||
|
||||
func newAppendTestRuleMatcher(path string) *cmdMatcher {
|
||||
return newCmdMatcher(path,
|
||||
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
|
||||
Reference in New Issue
Block a user