fix(firewall): delete chain rules by line number (#2411)

- Fix #2334
- Parsing of iptables chains, contributing to progress for #1856
This commit is contained in:
Quentin McGaw
2024-08-17 20:12:22 +02:00
parent 09c47c740c
commit c33158c13c
14 changed files with 1229 additions and 62 deletions

View File

@@ -0,0 +1,98 @@
package firewall
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
)
// isDeleteMatchInstruction returns true if the iptables instruction
// is a delete instruction by rule matching. It returns false if the
// instruction is a delete instruction by line number, or not a delete
// instruction.
func isDeleteMatchInstruction(instruction string) bool {
fields := strings.Fields(instruction)
for i, field := range fields {
switch {
case field != "-D" && field != "--delete": //nolint:goconst
continue
case i == len(fields)-1: // malformed: missing chain name
return false
case i == len(fields)-2: // chain name is last field
return true
default:
// chain name is fields[i+1]
const base, bitLength = 10, 16
_, err := strconv.ParseUint(fields[i+2], base, bitLength)
return err != nil // not a line number
}
}
return false
}
func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
runner Runner, logger Logger) (err error) {
targetRule, err := parseIptablesInstruction(instruction)
if err != nil {
return fmt.Errorf("parsing iptables command: %w", err)
}
lineNumber, err := findLineNumber(ctx, iptablesBinary,
targetRule, runner, logger)
if err != nil {
return fmt.Errorf("finding iptables chain rule line number: %w", err)
} else if lineNumber == 0 {
logger.Debug("rule matching \"" + instruction + "\" not found")
return nil
}
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
instruction, lineNumber))
cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return err
}
return nil
}
// findLineNumber finds the line number of an iptables rule.
// It returns 0 if the rule is not found.
func findLineNumber(ctx context.Context, iptablesBinary string,
instruction iptablesInstruction, runner Runner, logger Logger) (
lineNumber uint16, err error) {
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
"--line-numbers", "-n", "-v"}
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return 0, err
}
chain, err := parseChain(output)
if err != nil {
return 0, fmt.Errorf("parsing chain list: %w", err)
}
for _, rule := range chain.rules {
if instruction.equalToRule(instruction.table, chain.name, rule) {
return rule.lineNumber, nil
}
}
return 0, nil
}

View File

@@ -0,0 +1,188 @@
package firewall
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func Test_isDeleteMatchInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
instruction string
isDeleteMatch bool
}{
"not_delete": {
instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT",
},
"malformed_missing_chain_name": {
instruction: "-t nat -D",
},
"delete_chain_name_last_field": {
instruction: "-t nat --delete PREROUTING",
isDeleteMatch: true,
},
"delete_match": {
instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT",
isDeleteMatch: true,
},
"delete_line_number_last_field": {
instruction: "-t nat -D PREROUTING 2",
},
"delete_line_number": {
instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT",
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
isDeleteMatch := isDeleteMatchInstruction(testCase.instruction)
assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch)
})
}
}
func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
"^--line-numbers$", "^-n$", "^-v$")
}
func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()
const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")
testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
"rule_not_found": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
num pkts bytes target prot opt in out source destination
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
return logger
},
},
"rule_found_delete_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
ctx := context.Background()
instruction := testCase.instruction
var runner *MockRunner
if testCase.makeRunner != nil {
runner = testCase.makeRunner(ctrl)
}
var logger *MockLogger
if testCase.makeLogger != nil {
logger = testCase.makeLogger(ctrl)
}
err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -0,0 +1,13 @@
package firewall
import "github.com/qdm12/golibs/command"
type Runner interface {
Run(cmd command.ExecCmd) (output string, err error)
}
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}

View File

@@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()
c.logger.Debug(c.ip6Tables + " " + instruction)
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
@@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
case "ACCEPT", "DROP": //nolint:goconst
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}

View File

@@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
c.logger.Debug(c.ipTables + " " + instruction)
if isDeleteMatchInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}
flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
@@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp"
protocol = "tcp" //nolint:goconst
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,

381
internal/firewall/list.go Normal file
View File

@@ -0,0 +1,381 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type chain struct {
name string
policy string
packets uint64
bytes uint64
rules []chainRule
}
type chainRule struct {
lineNumber uint16 // starts from 1 and cannot be zero.
packets uint64
bytes uint64
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
protocol string // "tcp", "udp" or "" for all protocols.
inputInterface string // input interface, for example "tun0" or "*""
outputInterface string // output interface, for example "eth0" or "*""
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
destinationPort uint16 // Not specified if set to zero.
redirPorts []uint16 // Not specified if empty.
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
}
var (
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
)
func parseChain(iptablesOutput string) (c chain, err error) {
// Text example:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// pkts bytes target prot opt in out source destination
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
iptablesOutput = strings.TrimSpace(iptablesOutput)
linesWithComments := strings.Split(iptablesOutput, "\n")
// Filter out lines starting with a '#' character
lines := make([]string, 0, len(linesWithComments))
for _, line := range linesWithComments {
if strings.HasPrefix(line, "#") {
continue
}
lines = append(lines, line)
}
const minLines = 2 // chain general information line + legend line
if len(lines) < minLines {
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
ErrChainListMalformed, iptablesOutput)
}
c, err = parseChainGeneralDataLine(lines[0])
if err != nil {
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
}
// Sanity check for the legend line
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
legendLine := strings.TrimSpace(lines[1])
legendFields := strings.Fields(legendLine)
if !slices.Equal(expectedLegendFields, legendFields) {
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
}
lines = lines[2:] // remove chain general information line and legend line
if len(lines) == 0 {
return c, nil
}
c.rules = make([]chainRule, len(lines))
for i, line := range lines {
c.rules[i], err = parseChainRuleLine(line)
if err != nil {
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
}
}
return c, nil
}
// parseChainGeneralDataLine parses the first line of iptables chain list output.
// For example, it can parse the following line:
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
// It returns a chain struct with the parsed data.
func parseChainGeneralDataLine(line string) (base chain, err error) {
line = strings.TrimSpace(line)
runesToRemove := []rune{'(', ')', ','}
for _, r := range runesToRemove {
line = strings.ReplaceAll(line, string(r), "")
}
fields := strings.Fields(line)
const expectedNumberOfFields = 8
if len(fields) != expectedNumberOfFields {
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
ErrChainListMalformed, expectedNumberOfFields, line)
}
// Sanity checks
indexToExpectedValue := map[int]string{
0: "Chain",
2: "policy",
5: "packets",
7: "bytes",
}
for index, expectedValue := range indexToExpectedValue {
if fields[index] == expectedValue {
continue
}
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
ErrChainListMalformed, expectedValue, index, line)
}
base.name = fields[1] // chain name could be custom
base.policy = fields[3]
err = checkTarget(base.policy)
if err != nil {
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
}
packets, err := parseMetricSize(fields[4])
if err != nil {
return chain{}, fmt.Errorf("parsing packets: %w", err)
}
base.packets = packets
bytes, err := parseMetricSize(fields[6])
if err != nil {
return chain{}, fmt.Errorf("parsing bytes: %w", err)
}
base.bytes = bytes
return base, nil
}
var (
ErrChainRuleMalformed = errors.New("chain rule is malformed")
)
func parseChainRuleLine(line string) (rule chainRule, err error) {
line = strings.TrimSpace(line)
if line == "" {
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
}
fields := strings.Fields(line)
const minFields = 10
if len(fields) < minFields {
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
}
for fieldIndex, field := range fields[:minFields] {
err = parseChainRuleField(fieldIndex, field, &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
}
}
if len(fields) > minFields {
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
if err != nil {
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
}
}
return rule, nil
}
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
if field == "" {
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
}
const (
numIndex = iota
packetsIndex
bytesIndex
targetIndex
protocolIndex
optIndex
inputInterfaceIndex
outputInterfaceIndex
sourceIndex
destinationIndex
)
switch fieldIndex {
case numIndex:
rule.lineNumber, err = parseLineNumber(field)
if err != nil {
return fmt.Errorf("parsing line number: %w", err)
}
case packetsIndex:
rule.packets, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing packets: %w", err)
}
case bytesIndex:
rule.bytes, err = parseMetricSize(field)
if err != nil {
return fmt.Errorf("parsing bytes: %w", err)
}
case targetIndex:
err = checkTarget(field)
if err != nil {
return fmt.Errorf("checking target: %w", err)
}
rule.target = field
case protocolIndex:
rule.protocol, err = parseProtocol(field)
if err != nil {
return fmt.Errorf("parsing protocol: %w", err)
}
case optIndex: // ignored
case inputInterfaceIndex:
rule.inputInterface = field
case outputInterfaceIndex:
rule.outputInterface = field
case sourceIndex:
rule.source, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case destinationIndex:
rule.destination, err = parseIPPrefix(field)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
}
return nil
}
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
for i := 0; i < len(optionalFields); i++ {
key := optionalFields[i]
switch key {
case "tcp", "udp":
i++
value := optionalFields[i]
value = strings.TrimPrefix(value, "dpt:")
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port %q: %w", value, err)
}
rule.destinationPort = uint16(destinationPort)
case "redir":
i++
switch optionalFields[i] {
case "ports":
i++
ports, err := parsePortsCSV(optionalFields[i])
if err != nil {
return fmt.Errorf("parsing redirection ports: %w", err)
}
rule.redirPorts = ports
default:
return fmt.Errorf("%w: unexpected optional field: %s",
ErrChainRuleMalformed, optionalFields[i])
}
case "ctstate":
i++
rule.ctstate = strings.Split(optionalFields[i], ",")
default:
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
}
}
return nil
}
func parsePortsCSV(s string) (ports []uint16, err error) {
if s == "" {
return nil, nil
}
fields := strings.Split(s, ",")
ports = make([]uint16, len(fields))
for i, field := range fields {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(field, base, bitLength)
if err != nil {
return nil, fmt.Errorf("parsing port %q: %w", field, err)
}
ports[i] = uint16(port)
}
return ports, nil
}
var (
ErrLineNumberIsZero = errors.New("line number is zero")
)
func parseLineNumber(s string) (n uint16, err error) {
const base, bitLength = 10, 16
lineNumber, err := strconv.ParseUint(s, base, bitLength)
if err != nil {
return 0, err
} else if lineNumber == 0 {
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
}
return uint16(lineNumber), nil
}
var (
ErrTargetUnknown = errors.New("unknown target")
)
func checkTarget(target string) (err error) {
switch target {
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
return nil
}
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
}
var (
ErrProtocolUnknown = errors.New("unknown protocol")
)
func parseProtocol(s string) (protocol string, err error) {
switch s {
case "0":
case "6":
protocol = "tcp"
case "17":
protocol = "udp"
default:
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
}
return protocol, nil
}
var (
ErrMetricSizeMalformed = errors.New("metric size is malformed")
)
// parseMetricSize parses a metric size string like 140K or 226M and
// returns the raw integer matching it.
func parseMetricSize(size string) (n uint64, err error) {
if size == "" {
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
}
//nolint:gomnd
multiplerLetterToValue := map[byte]uint64{
'K': 1000,
'M': 1000000,
'G': 1000000000,
'T': 1000000000000,
}
lastCharacter := size[len(size)-1]
multiplier, ok := multiplerLetterToValue[lastCharacter]
if ok { // multiplier present
size = size[:len(size)-1]
} else {
multiplier = 1
}
const base, bitLength = 10, 64
n, err = strconv.ParseUint(size, base, bitLength)
if err != nil {
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
}
n *= multiplier
return n, nil
}

View File

@@ -0,0 +1,121 @@
package firewall
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseChain(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
iptablesOutput string
table chain
errWrapped error
errMessage string
}{
"no_output": {
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: ",
},
"single_line_only": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: not enough lines to process in: " +
"Chain INPUT (policy ACCEPT 140K packets, 226M bytes)",
},
"malformed_general_data_line": {
iptablesOutput: `Chain INPUT
num pkts bytes target prot opt in out source destination`,
errWrapped: ErrChainListMalformed,
errMessage: "parsing chain general data line: iptables chain list output is malformed: " +
"expected 8 fields in \"Chain INPUT\"",
},
"malformed_legend": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source`,
errWrapped: ErrChainListMalformed,
errMessage: "iptables chain list output is malformed: legend " +
"\"num pkts bytes target prot opt in out source\" " +
"is not the expected \"num pkts bytes target prot opt in out source destination\"",
},
"no_rule": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
},
},
"some_rules": {
iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
num pkts bytes target prot opt in out source destination
1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0
`,
table: chain{
name: "INPUT",
policy: "ACCEPT",
packets: 140000,
bytes: 226000000,
rules: []chainRule{
{
lineNumber: 1,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "udp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 2,
packets: 0,
bytes: 0,
target: "ACCEPT",
protocol: "tcp",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("0.0.0.0/0"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
destinationPort: 55405,
},
{
lineNumber: 3,
packets: 0,
bytes: 0,
target: "DROP",
protocol: "",
inputInterface: "tun0",
outputInterface: "*",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("0.0.0.0/0"),
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
table, err := parseChain(testCase.iptablesOutput)
assert.Equal(t, testCase.table, table)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

View File

@@ -5,12 +5,6 @@ import (
"net/netip"
)
type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}
func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) {
c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+
"no default route matching its family", subnet))

View File

@@ -0,0 +1,3 @@
package firewall
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger

View File

@@ -0,0 +1,109 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger)
// Package firewall is a generated GoMock package.
package firewall
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
command "github.com/qdm12/golibs/command"
)
// MockRunner is a mock of Runner interface.
type MockRunner struct {
ctrl *gomock.Controller
recorder *MockRunnerMockRecorder
}
// MockRunnerMockRecorder is the mock recorder for MockRunner.
type MockRunnerMockRecorder struct {
mock *MockRunner
}
// NewMockRunner creates a new mock instance.
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
mock := &MockRunner{ctrl: ctrl}
mock.recorder = &MockRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
}
// MockLogger is a mock of Logger interface.
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger.
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance.
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method.
func (m *MockLogger) Debug(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Debug", arg0)
}
// Debug indicates an expected call of Debug.
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
}
// Error mocks base method.
func (m *MockLogger) Error(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Error", arg0)
}
// Error indicates an expected call of Error.
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
}
// Info mocks base method.
func (m *MockLogger) Info(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Info", arg0)
}
// Info indicates an expected call of Info.
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
}

166
internal/firewall/parse.go Normal file
View File

@@ -0,0 +1,166 @@
package firewall
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
)
type iptablesInstruction struct {
table string // defaults to "filter", and can be "nat" for example.
append bool
chain string // for example INPUT, PREROUTING. Cannot be empty.
target string // for example ACCEPT. Can be empty.
protocol string // "tcp" or "udp" or "" for all protocols.
inputInterface string // for example "tun0" or "" for any interface.
outputInterface string // for example "tun0" or "" for any interface.
source netip.Prefix // if not valid, then it is unspecified.
destination netip.Prefix // if not valid, then it is unspecified.
destinationPort uint16 // if zero, there is no destination port
toPorts []uint16 // if empty, there is no redirection
ctstate []string // if empty, there is no ctstate
}
func (i *iptablesInstruction) setDefaults() {
if i.table == "" {
i.table = "filter"
}
}
// equalToRule ignores the append boolean flag of the instruction to compare against the rule.
func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) {
switch {
case i.table != table:
return false
case i.chain != chain:
return false
case i.target != rule.target:
return false
case i.protocol != rule.protocol:
return false
case i.destinationPort != rule.destinationPort:
return false
case !slices.Equal(i.toPorts, rule.redirPorts):
return false
case !slices.Equal(i.ctstate, rule.ctstate):
return false
case !networkInterfacesEqual(i.inputInterface, rule.inputInterface):
return false
case !networkInterfacesEqual(i.outputInterface, rule.outputInterface):
return false
case !ipPrefixesEqual(i.source, rule.source):
return false
case !ipPrefixesEqual(i.destination, rule.destination):
return false
default:
return true
}
}
// instruction can be "" which equivalent to the "*" chain rule interface.
func networkInterfacesEqual(instruction, chainRule string) bool {
return instruction == chainRule || (instruction == "" && chainRule == "*")
}
func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool {
return instruction == chainRule ||
(!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified())
}
var (
ErrIptablesCommandMalformed = errors.New("iptables command is malformed")
)
func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) {
if s == "" {
return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed)
}
fields := strings.Fields(s)
if len(fields)%2 != 0 {
return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q",
ErrIptablesCommandMalformed, len(fields), s)
}
for i := 0; i < len(fields); i += 2 {
key := fields[i]
value := fields[i+1]
err = parseInstructionFlag(key, value, &instruction)
if err != nil {
return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err)
}
}
instruction.setDefaults()
return instruction, nil
}
func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) {
switch key {
case "-t", "--table":
instruction.table = value
case "-D", "--delete":
instruction.append = false
instruction.chain = value
case "-A", "--append":
instruction.append = true
instruction.chain = value
case "-j", "--jump":
instruction.target = value
case "-p", "--protocol":
instruction.protocol = value
case "-m", "--match": // ignore match
case "-i", "--in-interface":
instruction.inputInterface = value
case "-o", "--out-interface":
instruction.outputInterface = value
case "-s", "--source":
instruction.source, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing source IP CIDR: %w", err)
}
case "-d", "--destination":
instruction.destination, err = parseIPPrefix(value)
if err != nil {
return fmt.Errorf("parsing destination IP CIDR: %w", err)
}
case "--dport":
const base, bitLength = 10, 16
destinationPort, err := strconv.ParseUint(value, base, bitLength)
if err != nil {
return fmt.Errorf("parsing destination port: %w", err)
}
instruction.destinationPort = uint16(destinationPort)
case "--ctstate":
instruction.ctstate = strings.Split(value, ",")
case "--to-ports":
portStrings := strings.Split(value, ",")
instruction.toPorts = make([]uint16, len(portStrings))
for i, portString := range portStrings {
const base, bitLength = 10, 16
port, err := strconv.ParseUint(portString, base, bitLength)
if err != nil {
return fmt.Errorf("parsing port redirection: %w", err)
}
instruction.toPorts[i] = uint16(port)
}
default:
return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key)
}
return nil
}
func parseIPPrefix(value string) (prefix netip.Prefix, err error) {
slashIndex := strings.Index(value, "/")
if slashIndex >= 0 {
return netip.ParsePrefix(value)
}
ip, err := netip.ParseAddr(value)
if err != nil {
return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err)
}
return netip.PrefixFrom(ip, ip.BitLen()), nil
}

View File

@@ -0,0 +1,138 @@
package firewall
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseIptablesInstruction(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
s string
instruction iptablesInstruction
errWrapped error
errMessage string
}{
"no_instruction": {
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: empty instruction",
},
"uneven_fields": {
s: "-A",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"",
},
"unknown_key": {
s: "-x something",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"",
},
"one_pair": {
s: "-A INPUT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
},
},
"instruction_A": {
s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT",
instruction: iptablesInstruction{
table: "filter",
chain: "INPUT",
append: true,
inputInterface: "tun0",
protocol: "tcp",
source: netip.MustParsePrefix("1.2.3.4/32"),
destination: netip.MustParsePrefix("5.6.7.8/32"),
destinationPort: 10000,
target: "ACCEPT",
},
},
"nat_redirection": {
s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
instruction: iptablesInstruction{
table: "nat",
chain: "PREROUTING",
append: false,
inputInterface: "tun0",
protocol: "tcp",
destinationPort: 43716,
target: "REDIRECT",
toPorts: []uint16{5678},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
rule, err := parseIptablesInstruction(testCase.s)
assert.Equal(t, testCase.instruction, rule)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
func Test_parseIPPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
value string
prefix netip.Prefix
errMessage string
}{
"empty": {
errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`,
},
"invalid": {
value: "invalid",
errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`,
},
"valid_ipv4_with_bits": {
value: "10.0.0.0/16",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16),
},
"valid_ipv4_without_bits": {
value: "10.0.0.4",
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32),
},
"valid_ipv6_with_bits": {
value: "2001:db8::/32",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
32),
},
"valid_ipv6_without_bits": {
value: "2001:db8::",
prefix: netip.PrefixFrom(
netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}),
128),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix, err := parseIPPrefix(testCase.value)
assert.Equal(t, testCase.prefix, prefix)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,50 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
// Package firewall is a generated GoMock package.
package firewall
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
command "github.com/qdm12/golibs/command"
)
// MockRunner is a mock of Runner interface.
type MockRunner struct {
ctrl *gomock.Controller
recorder *MockRunnerMockRecorder
}
// MockRunnerMockRecorder is the mock recorder for MockRunner.
type MockRunnerMockRecorder struct {
mock *MockRunner
}
// NewMockRunner creates a new mock instance.
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
mock := &MockRunner{ctrl: ctrl}
mock.recorder = &MockRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
return m.recorder
}
// Run mocks base method.
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run.
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
}

View File

@@ -11,8 +11,6 @@ import (
"github.com/stretchr/testify/require"
)
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
func newAppendTestRuleMatcher(path string) *cmdMatcher {
return newCmdMatcher(path,
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",