From 7ba98af1cc4a969ff44ca4d9166b57654e534677 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 19 Apr 2021 09:24:46 -0400 Subject: [PATCH] Feature/Bugfix: IPv6 blocking (#428) - Feature/Bugfix: Block all IPv6 traffic with `ip6tables` by default - Feature: Adapt existing firewall code to handle IPv4 and IPv6, depending on user inputs and environment - Maintenance: improve error wrapping in the firewall package --- internal/firewall/enable.go | 32 ++++++--- internal/firewall/firewall.go | 1 + internal/firewall/ip6tables.go | 47 ++++++++++++++ internal/firewall/iptables.go | 107 ++++++++++++++++++++++--------- internal/firewall/iptablesmix.go | 21 ++++++ 5 files changed, 171 insertions(+), 37 deletions(-) create mode 100644 internal/firewall/ip6tables.go create mode 100644 internal/firewall/iptablesmix.go diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 7edb4780..c51ba063 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -2,11 +2,18 @@ package firewall import ( "context" + "errors" "fmt" "github.com/qdm12/gluetun/internal/constants" ) +var ( + ErrEnable = errors.New("failed enabling firewall") + ErrDisable = errors.New("failed disabling firewall") + ErrUserPostRules = errors.New("cannot run user post firewall rules") +) + func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() @@ -23,7 +30,7 @@ func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) if !enabled { c.logger.Info("disabling...") if err = c.disable(ctx); err != nil { - return err + return fmt.Errorf("%w: %s", ErrDisable, err) } c.enabled = false c.logger.Info("disabled successfully") @@ -33,7 +40,7 @@ func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) c.logger.Info("enabling...") if err := c.enable(ctx); err != nil { - return err + return fmt.Errorf("%w: %s", ErrEnable, err) } c.enabled = true c.logger.Info("enabled successfully") @@ -45,7 +52,10 @@ func (c *configurator) disable(ctx context.Context) (err error) { if err = c.clearAllRules(ctx); err != nil { return fmt.Errorf("cannot disable firewall: %w", err) } - if err = c.setAllPolicies(ctx, "ACCEPT"); err != nil { + if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil { + return fmt.Errorf("cannot disable firewall: %w", err) + } + if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil { return fmt.Errorf("cannot disable firewall: %w", err) } return nil @@ -56,20 +66,26 @@ func (c *configurator) fallbackToDisabled(ctx context.Context) { if ctx.Err() != nil { return } - if err := c.SetEnabled(ctx, false); err != nil { - c.logger.Error(err) + if err := c.disable(ctx); err != nil { + c.logger.Error("failed reversing firewall changes: " + err.Error()) } } func (c *configurator) enable(ctx context.Context) (err error) { - if err = c.setAllPolicies(ctx, "DROP"); err != nil { + touched := false + if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + touched = true + + if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } const remove = false defer func() { - if err != nil { + if touched && err != nil { c.fallbackToDisabled(ctx) } }() @@ -121,7 +137,7 @@ func (c *configurator) enable(ctx context.Context) (err error) { } if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) + return fmt.Errorf("%w: %s", ErrUserPostRules, err) } return nil diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index b098db60..cb9ac873 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -34,6 +34,7 @@ type configurator struct { //nolint:maligned routing routing.Routing openFile os.OpenFileFunc // for custom iptables rules iptablesMutex sync.Mutex + ip6tablesMutex sync.Mutex debug bool defaultInterface string defaultGateway net.IP diff --git a/internal/firewall/ip6tables.go b/internal/firewall/ip6tables.go new file mode 100644 index 00000000..9666badd --- /dev/null +++ b/internal/firewall/ip6tables.go @@ -0,0 +1,47 @@ +package firewall + +import ( + "context" + "errors" + "fmt" + "strings" +) + +var ( + ErrIP6Tables = errors.New("failed ip6tables command") +) + +func (c *configurator) runIP6tablesInstructions(ctx context.Context, instructions []string) error { + for _, instruction := range instructions { + if err := c.runIP6tablesInstruction(ctx, instruction); err != nil { + return err + } + } + return nil +} + +func (c *configurator) runIP6tablesInstruction(ctx context.Context, instruction string) error { + c.ip6tablesMutex.Lock() // only one ip6tables command at once + defer c.ip6tablesMutex.Unlock() + if c.debug { + fmt.Println("ip6tables " + instruction) + } + flags := strings.Fields(instruction) + if output, err := c.commander.Run(ctx, "ip6tables", flags...); err != nil { + return fmt.Errorf("%w \"ip6tables %s\": %s: %s", ErrIP6Tables, instruction, output, err) + } + return nil +} + +func (c *configurator) setIPv6AllPolicies(ctx context.Context, policy string) error { + switch policy { + case "ACCEPT", "DROP": + default: + return fmt.Errorf("policy %q not recognized", policy) + } + return c.runIP6tablesInstructions(ctx, []string{ + "--policy INPUT " + policy, + "--policy OUTPUT " + policy, + "--policy FORWARD " + policy, + }) +} diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 1d224b45..4b9d6574 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -2,6 +2,7 @@ package firewall import ( "context" + "errors" "fmt" "io/ioutil" "net" @@ -11,6 +12,14 @@ import ( "github.com/qdm12/gluetun/internal/models" ) +var ( + ErrIPTablesVersionTooShort = errors.New("iptables version string is too short") + ErrIPTables = errors.New("failed iptables command") + ErrPolicyUnknown = errors.New("unknown policy") + ErrClearRules = errors.New("cannot clear all rules") + ErrSetIPtablesPolicies = errors.New("cannot set iptables policies") +) + func appendOrDelete(remove bool) string { if remove { return "--delete" @@ -43,7 +52,7 @@ func (c *configurator) Version(ctx context.Context) (string, error) { words := strings.Fields(output) const minWords = 2 if len(words) < minWords { - return "", fmt.Errorf("iptables --version: output is too short: %q", output) + return "", fmt.Errorf("%w: %s", ErrIPTablesVersionTooShort, output) } return words[1], nil } @@ -65,55 +74,68 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s } flags := strings.Fields(instruction) if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil { - return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err) + return fmt.Errorf("%w \"iptables %s\": %s: %s", ErrIPTables, instruction, output, err) } return nil } func (c *configurator) clearAllRules(ctx context.Context) error { - return c.runIptablesInstructions(ctx, []string{ + if err := c.runMixedIptablesInstructions(ctx, []string{ "--flush", // flush all chains "--delete-chain", // delete all chains - }) + }); err != nil { + return fmt.Errorf("%w: %s", ErrClearRules, err.Error()) + } + return nil } -func (c *configurator) setAllPolicies(ctx context.Context, policy string) error { +func (c *configurator) setIPv4AllPolicies(ctx context.Context, policy string) error { switch policy { case "ACCEPT", "DROP": default: - return fmt.Errorf("policy %q not recognized", policy) + return fmt.Errorf("%w: %s: %s", ErrSetIPtablesPolicies, ErrPolicyUnknown, policy) } - return c.runIptablesInstructions(ctx, []string{ - fmt.Sprintf("--policy INPUT %s", policy), - fmt.Sprintf("--policy OUTPUT %s", policy), - fmt.Sprintf("--policy FORWARD %s", policy), - }) + if err := c.runIptablesInstructions(ctx, []string{ + "--policy INPUT " + policy, + "--policy OUTPUT " + policy, + "--policy FORWARD " + policy, + }); err != nil { + return fmt.Errorf("%w: %s", ErrSetIPtablesPolicies, err) + } + return nil } func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { - return c.runIptablesInstruction(ctx, fmt.Sprintf( + return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( "%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf, )) } func (c *configurator) acceptInputToSubnet(ctx context.Context, intf string, destination net.IPNet, remove bool) error { + isIP4Subnet := destination.IP.To4() != nil + interfaceFlag := "-i " + intf if intf == "*" { // all interfaces interfaceFlag = "" } - return c.runIptablesInstruction(ctx, fmt.Sprintf( - "%s INPUT %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, destination.String(), - )) + + instruction := fmt.Sprintf("%s INPUT %s -d %s -j ACCEPT", + appendOrDelete(remove), interfaceFlag, destination.String()) + + if isIP4Subnet { + return c.runIptablesInstruction(ctx, instruction) + } + return c.runIP6tablesInstruction(ctx, instruction) } func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error { - return c.runIptablesInstruction(ctx, fmt.Sprintf( + return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( "%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf, )) } func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { - return c.runIptablesInstructions(ctx, []string{ + return c.runMixedIptablesInstructions(ctx, []string{ fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), }) @@ -121,22 +143,33 @@ func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remo func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error { - return c.runIptablesInstruction(ctx, - fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", - appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)) + instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", + appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, + connection.Protocol, connection.Port) + isIPv4 := connection.IP.To4() != nil + if isIPv4 { + return c.runIptablesInstruction(ctx, instruction) + } + return c.runIP6tablesInstruction(ctx, instruction) } // Thanks to @npawelek. func (c *configurator) acceptOutputFromIPToSubnet(ctx context.Context, intf string, sourceIP net.IP, destinationSubnet net.IPNet, remove bool) error { + doIPv4 := sourceIP.To4() != nil && destinationSubnet.IP.To4() != nil + interfaceFlag := "-o " + intf if intf == "*" { // all interfaces interfaceFlag = "" } - return c.runIptablesInstruction(ctx, fmt.Sprintf( - "%s OUTPUT %s -s %s -d %s -j ACCEPT", - appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String(), - )) + + instruction := fmt.Sprintf("%s OUTPUT %s -s %s -d %s -j ACCEPT", + appendOrDelete(remove), interfaceFlag, sourceIP.String(), destinationSubnet.String()) + + if doIPv4 { + return c.runIptablesInstruction(ctx, instruction) + } + return c.runIP6tablesInstruction(ctx, instruction) } // Used for port forwarding, with intf set to tun. @@ -145,7 +178,7 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port if intf == "*" { // all interfaces interfaceFlag = "" } - return c.runIptablesInstructions(ctx, []string{ + return c.runMixedIptablesInstructions(ctx, []string{ fmt.Sprintf("%s INPUT %s -p tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port), fmt.Sprintf("%s INPUT %s -p udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port), }) @@ -178,16 +211,32 @@ func (c *configurator) runUserPostRules(ctx context.Context, filepath string, re } }() for _, line := range lines { - if !strings.HasPrefix(line, "iptables ") { + var ipv4 bool + var rule string + switch { + case strings.HasPrefix(line, "iptables "): + ipv4 = true + rule = strings.TrimPrefix(line, "iptables ") + case strings.HasPrefix(line, "ip6tables "): + ipv4 = false + rule = strings.TrimPrefix(line, "ip6tables ") + default: continue } - rule := strings.TrimPrefix(line, "iptables ") + if remove { rule = flipRule(rule) } - if err = c.runIptablesInstruction(ctx, rule); err != nil { - return fmt.Errorf("cannot run custom rule: %w", err) + + if ipv4 { + err = c.runIptablesInstruction(ctx, rule) + } else { + err = c.runIP6tablesInstruction(ctx, rule) } + if err != nil { + return err + } + successfulRules = append(successfulRules, rule) } return nil diff --git a/internal/firewall/iptablesmix.go b/internal/firewall/iptablesmix.go new file mode 100644 index 00000000..e3456330 --- /dev/null +++ b/internal/firewall/iptablesmix.go @@ -0,0 +1,21 @@ +package firewall + +import ( + "context" +) + +func (c *configurator) runMixedIptablesInstructions(ctx context.Context, instructions []string) error { + for _, instruction := range instructions { + if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil { + return err + } + } + return nil +} + +func (c *configurator) runMixedIptablesInstruction(ctx context.Context, instruction string) error { + if err := c.runIptablesInstruction(ctx, instruction); err != nil { + return err + } + return c.runIP6tablesInstruction(ctx, instruction) +}