372 lines
10 KiB
Go
372 lines
10 KiB
Go
package firewall
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
type chain struct {
|
|
name string
|
|
policy string
|
|
packets uint64
|
|
bytes uint64
|
|
rules []chainRule
|
|
}
|
|
|
|
type chainRule struct {
|
|
lineNumber uint16 // starts from 1 and cannot be zero.
|
|
packets uint64
|
|
bytes uint64
|
|
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
|
protocol string // "icmp", "tcp", "udp" or "" for all protocols.
|
|
inputInterface string // input interface, for example "tun0" or "*""
|
|
outputInterface string // output interface, for example "eth0" or "*""
|
|
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
|
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
|
destinationPort uint16 // Not specified if set to zero.
|
|
redirPorts []uint16 // Not specified if empty.
|
|
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
|
}
|
|
|
|
var ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
|
|
|
func parseChain(iptablesOutput string) (c chain, err error) {
|
|
// Text example:
|
|
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
|
// pkts bytes target prot opt in out source destination
|
|
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
|
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
|
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
|
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
|
linesWithComments := strings.Split(iptablesOutput, "\n")
|
|
|
|
// Filter out lines starting with a '#' character
|
|
lines := make([]string, 0, len(linesWithComments))
|
|
for _, line := range linesWithComments {
|
|
if strings.HasPrefix(line, "#") {
|
|
continue
|
|
}
|
|
lines = append(lines, line)
|
|
}
|
|
|
|
const minLines = 2 // chain general information line + legend line
|
|
if len(lines) < minLines {
|
|
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
|
ErrChainListMalformed, iptablesOutput)
|
|
}
|
|
|
|
c, err = parseChainGeneralDataLine(lines[0])
|
|
if err != nil {
|
|
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
|
}
|
|
|
|
// Sanity check for the legend line
|
|
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
|
legendLine := strings.TrimSpace(lines[1])
|
|
legendFields := strings.Fields(legendLine)
|
|
if !slices.Equal(expectedLegendFields, legendFields) {
|
|
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
|
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
|
}
|
|
|
|
lines = lines[2:] // remove chain general information line and legend line
|
|
if len(lines) == 0 {
|
|
return c, nil
|
|
}
|
|
|
|
c.rules = make([]chainRule, len(lines))
|
|
for i, line := range lines {
|
|
c.rules[i], err = parseChainRuleLine(line)
|
|
if err != nil {
|
|
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
|
}
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// parseChainGeneralDataLine parses the header line of an iptables chain
// listing, for example:
//
//	Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
//
// It returns a chain whose name, policy target and policy packet/byte
// counters are set; rules are left nil. After removing '(', ')' and ','
// characters the line must split into exactly 8 fields with the keywords
// "Chain", "policy", "packets" and "bytes" at positions 0, 2, 5 and 7.
// Otherwise an error wrapping ErrChainListMalformed is returned. The policy
// must be accepted by checkTarget and both counters by parseMetricSize.
// Any failure returns the zero chain.
//
// NOTE(review): user-defined chains are presumably listed as
// "Chain NAME (N references)" without a policy, and would be rejected here.
// Confirm callers only pass chains that have a policy.
func parseChainGeneralDataLine(line string) (base chain, err error) {
	// Drop the punctuation so the header splits into plain words.
	line = strings.Map(func(r rune) rune {
		switch r {
		case '(', ')', ',':
			return -1 // negative value removes the rune
		default:
			return r
		}
	}, strings.TrimSpace(line))

	fields := strings.Fields(line)
	const expectedNumberOfFields = 8
	if len(fields) != expectedNumberOfFields {
		return chain{}, fmt.Errorf("%w: expected %d fields in %q",
			ErrChainListMalformed, expectedNumberOfFields, line)
	}

	// Sanity check the fixed keywords in field order. Using an ordered array
	// instead of ranging over a map keeps the reported mismatch deterministic
	// when several keywords are wrong, and avoids a per-call allocation.
	keywords := [...]struct {
		index int
		value string
	}{
		{index: 0, value: "Chain"},
		{index: 2, value: "policy"},
		{index: 5, value: "packets"},
		{index: 7, value: "bytes"},
	}
	for _, keyword := range keywords {
		if fields[keyword.index] != keyword.value {
			return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
				ErrChainListMalformed, keyword.value, keyword.index, line)
		}
	}

	name := fields[1] // the chain name is not validated
	policy := fields[3]
	if err = checkTarget(policy); err != nil {
		return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
	}

	packets, err := parseMetricSize(fields[4])
	if err != nil {
		return chain{}, fmt.Errorf("parsing packets: %w", err)
	}

	byteCount, err := parseMetricSize(fields[6])
	if err != nil {
		return chain{}, fmt.Errorf("parsing bytes: %w", err)
	}

	return chain{
		name:    name,
		policy:  policy,
		packets: packets,
		bytes:   byteCount,
	}, nil
}
|
|
|
|
var ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
|
|
|
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
|
line = strings.TrimSpace(line)
|
|
if line == "" {
|
|
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
|
}
|
|
|
|
fields := strings.Fields(line)
|
|
|
|
const minFields = 10
|
|
if len(fields) < minFields {
|
|
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
|
}
|
|
|
|
for fieldIndex, field := range fields[:minFields] {
|
|
err = parseChainRuleField(fieldIndex, field, &rule)
|
|
if err != nil {
|
|
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
|
}
|
|
}
|
|
|
|
if len(fields) > minFields {
|
|
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
|
if err != nil {
|
|
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
|
}
|
|
}
|
|
|
|
return rule, nil
|
|
}
|
|
|
|
// parseChainRuleField parses the fixed column at position fieldIndex of a
// rule line and stores the result in the matching field of rule. The "opt"
// column (index 5) and any index outside 0-9 are accepted and ignored.
// An empty field yields an error wrapping ErrChainRuleMalformed.
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
	if field == "" {
		return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
	}

	// Column positions, matching the legend checked by parseChain.
	const (
		colNum = iota
		colPackets
		colBytes
		colTarget
		colProtocol
		colOpt
		colInputInterface
		colOutputInterface
		colSource
		colDestination
	)

	switch fieldIndex {
	case colNum:
		if rule.lineNumber, err = parseLineNumber(field); err != nil {
			return fmt.Errorf("parsing line number: %w", err)
		}
	case colPackets:
		if rule.packets, err = parseMetricSize(field); err != nil {
			return fmt.Errorf("parsing packets: %w", err)
		}
	case colBytes:
		if rule.bytes, err = parseMetricSize(field); err != nil {
			return fmt.Errorf("parsing bytes: %w", err)
		}
	case colTarget:
		if err = checkTarget(field); err != nil {
			return fmt.Errorf("checking target: %w", err)
		}
		rule.target = field
	case colProtocol:
		if rule.protocol, err = parseProtocol(field); err != nil {
			return fmt.Errorf("parsing protocol: %w", err)
		}
	case colOpt:
		// Deliberately ignored.
	case colInputInterface:
		rule.inputInterface = field
	case colOutputInterface:
		rule.outputInterface = field
	case colSource:
		if rule.source, err = parseIPPrefix(field); err != nil {
			return fmt.Errorf("parsing source IP CIDR: %w", err)
		}
	case colDestination:
		if rule.destination, err = parseIPPrefix(field); err != nil {
			return fmt.Errorf("parsing destination IP CIDR: %w", err)
		}
	}

	return nil
}
|
|
|
|
// parseChainRuleOptionalFields parses the option fields that follow the
// destination column of a rule line and stores their values in rule.
// Recognized options and the value field(s) each one consumes:
//
//	tcp|udp dpt:PORT   destination port, stored in rule.destinationPort
//	redir ports P1,P2  redirection ports, stored in rule.redirPorts
//	ctstate S1,S2      conntrack states, stored in rule.ctstate
//
// An unknown option yields an error wrapping ErrChainRuleMalformed. So does
// an option whose value is missing because the line ends early; previously
// that case indexed past the end of the slice and panicked.
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
	i := 0

	// nextValue moves i from the current field to the one after it and
	// returns that field. If the line ends first, it returns an error
	// instead of indexing out of range.
	nextValue := func() (string, error) {
		current := optionalFields[i]
		i++
		if i >= len(optionalFields) {
			return "", fmt.Errorf("%w: missing value after optional field: %s",
				ErrChainRuleMalformed, current)
		}
		return optionalFields[i], nil
	}

	for ; i < len(optionalFields); i++ {
		key := optionalFields[i]
		switch key {
		case "tcp", "udp":
			value, err := nextValue()
			if err != nil {
				return err
			}
			value = strings.TrimPrefix(value, "dpt:")
			const base, bitLength = 10, 16
			destinationPort, err := strconv.ParseUint(value, base, bitLength)
			if err != nil {
				return fmt.Errorf("parsing destination port %q: %w", value, err)
			}
			rule.destinationPort = uint16(destinationPort)
		case "redir":
			redirKind, err := nextValue()
			if err != nil {
				return err
			}
			if redirKind != "ports" {
				return fmt.Errorf("%w: unexpected optional field: %s",
					ErrChainRuleMalformed, redirKind)
			}
			portsCSV, err := nextValue()
			if err != nil {
				return err
			}
			ports, err := parsePortsCSV(portsCSV)
			if err != nil {
				return fmt.Errorf("parsing redirection ports: %w", err)
			}
			rule.redirPorts = ports
		case "ctstate":
			states, err := nextValue()
			if err != nil {
				return err
			}
			rule.ctstate = strings.Split(states, ",")
		default:
			return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
		}
	}

	return nil
}
|
|
|
|
// parsePortsCSV parses a comma-separated list of decimal port numbers, for
// example "8080,8443", into a slice of uint16. An empty string yields a nil
// slice and no error. Any entry that is not a decimal number in the uint16
// range makes the whole call fail with a nil slice.
func parsePortsCSV(s string) (ports []uint16, err error) {
	if s == "" {
		return nil, nil
	}

	parts := strings.Split(s, ",")
	ports = make([]uint16, 0, len(parts))
	for _, part := range parts {
		const decimal, portBits = 10, 16
		value, parseErr := strconv.ParseUint(part, decimal, portBits)
		if parseErr != nil {
			return nil, fmt.Errorf("parsing port %q: %w", part, parseErr)
		}
		ports = append(ports, uint16(value))
	}

	return ports, nil
}
|
|
|
|
var ErrLineNumberIsZero = errors.New("line number is zero")
|
|
|
|
func parseLineNumber(s string) (n uint16, err error) {
|
|
const base, bitLength = 10, 16
|
|
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
|
if err != nil {
|
|
return 0, err
|
|
} else if lineNumber == 0 {
|
|
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
|
}
|
|
return uint16(lineNumber), nil
|
|
}
|
|
|
|
var ErrTargetUnknown = errors.New("unknown target")
|
|
|
|
func checkTarget(target string) (err error) {
|
|
switch target {
|
|
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
|
return nil
|
|
}
|
|
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
|
}
|
|
|
|
var ErrProtocolUnknown = errors.New("unknown protocol")
|
|
|
|
func parseProtocol(s string) (protocol string, err error) {
|
|
switch s {
|
|
case "0":
|
|
case "1":
|
|
protocol = "icmp"
|
|
case "6":
|
|
protocol = "tcp"
|
|
case "17":
|
|
protocol = "udp"
|
|
default:
|
|
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
|
}
|
|
return protocol, nil
|
|
}
|
|
|
|
var ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
|
|
|
// parseMetricSize parses a metric size string like 140K or 226M and
|
|
// returns the raw integer matching it.
|
|
func parseMetricSize(size string) (n uint64, err error) {
|
|
if size == "" {
|
|
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
|
}
|
|
|
|
//nolint:mnd
|
|
multiplerLetterToValue := map[byte]uint64{
|
|
'K': 1000,
|
|
'M': 1000000,
|
|
'G': 1000000000,
|
|
'T': 1000000000000,
|
|
}
|
|
|
|
lastCharacter := size[len(size)-1]
|
|
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
|
if ok { // multiplier present
|
|
size = size[:len(size)-1]
|
|
} else {
|
|
multiplier = 1
|
|
}
|
|
|
|
const base, bitLength = 10, 64
|
|
n, err = strconv.ParseUint(size, base, bitLength)
|
|
if err != nil {
|
|
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
|
}
|
|
n *= multiplier
|
|
return n, nil
|
|
}
|