Files
gluetun/internal/firewall/enable.go
Quentin McGaw 7ba98af1cc Feature/Bugfix: IPv6 blocking (#428)
- Feature/Bugfix: Block all IPv6 traffic with `ip6tables` by default
- Feature: Adapt existing firewall code to handle IPv4 and IPv6, depending on user inputs and environment
- Maintenance: improve error wrapping in the firewall package
2021-04-19 09:24:46 -04:00

145 lines
3.9 KiB
Go

package firewall
import (
"context"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/constants"
)
// ErrEnable is the sentinel error wrapped around any failure that occurs
// while switching the firewall on.
var ErrEnable = errors.New("failed enabling firewall")

// ErrDisable is the sentinel error wrapped around any failure that occurs
// while switching the firewall off.
var ErrDisable = errors.New("failed disabling firewall")

// ErrUserPostRules is the sentinel error wrapped around a failure to apply
// the user supplied post firewall rules.
var ErrUserPostRules = errors.New("cannot run user post firewall rules")
// SetEnabled switches the firewall on or off while holding the state mutex.
// If the firewall is already in the requested state, it only logs that fact
// and returns nil.
//
// An enabling failure is returned wrapped with ErrEnable, and a disabling
// failure is returned wrapped with ErrDisable. The underlying error is
// appended as text only, so it is not part of the wrap chain.
func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) {
	c.stateMutex.Lock()
	defer c.stateMutex.Unlock()

	switch {
	case enabled && c.enabled:
		c.logger.Info("already enabled")
		return nil
	case !enabled && !c.enabled:
		c.logger.Info("already disabled")
		return nil
	case enabled:
		c.logger.Info("enabling...")
		if err = c.enable(ctx); err != nil {
			return fmt.Errorf("%w: %s", ErrEnable, err)
		}
		c.enabled = true
		c.logger.Info("enabled successfully")
		return nil
	default:
		c.logger.Info("disabling...")
		if err = c.disable(ctx); err != nil {
			return fmt.Errorf("%w: %s", ErrDisable, err)
		}
		c.enabled = false
		c.logger.Info("disabled successfully")
		return nil
	}
}
// disable opens the firewall. It first clears all rules and then sets every
// IPv4 and IPv6 chain policy to ACCEPT. It stops at the first step that fails
// and returns that step's error prefixed with "cannot disable firewall".
func (c *configurator) disable(ctx context.Context) (err error) {
	steps := [...]func() error{
		func() error { return c.clearAllRules(ctx) },
		func() error { return c.setIPv4AllPolicies(ctx, "ACCEPT") },
		func() error { return c.setIPv6AllPolicies(ctx, "ACCEPT") },
	}
	for _, step := range steps {
		if err = step(); err != nil {
			return fmt.Errorf("cannot disable firewall: %w", err)
		}
	}
	return nil
}
// fallbackToDisabled is intended to be deferred while enabling the firewall.
// It tries to undo partially applied changes by disabling the firewall.
// It does nothing when ctx is already canceled or past its deadline. A
// failure is only logged, because there is no caller to return it to.
func (c *configurator) fallbackToDisabled(ctx context.Context) {
	if ctx.Err() == nil {
		disableErr := c.disable(ctx)
		if disableErr != nil {
			c.logger.Error("failed reversing firewall changes: " + disableErr.Error())
		}
	}
}
// enable sets every IPv4 and IPv6 chain policy to DROP and then adds ACCEPT
// rules for:
//   - loopback ("lo") input and output,
//   - established and related traffic,
//   - output to the VPN connection on the default interface, when its IP is set,
//   - output through the tunnel interface,
//   - output to each local network and to each configured outbound subnet,
//   - input for each local network subnet and for each allowed input port.
// Finally it runs the user rules found in /iptables/post-rules.txt.
//
// Every error is returned prefixed with "cannot enable firewall", except a
// post-rules failure, which is wrapped with ErrUserPostRules. If any step
// fails after the IPv4 policies were set to DROP, fallbackToDisabled is
// called to clear the rules and reset the policies to ACCEPT. That rollback
// is skipped if ctx is already done.
func (c *configurator) enable(ctx context.Context) (err error) {
	touched := false
	// Register the rollback before changing anything. A failure at any later
	// step, including setting the IPv6 policies, must revert the IPv4 DROP
	// policies that were already applied. Otherwise the firewall would be
	// left dropping traffic with no accept rules in place.
	defer func() {
		if touched && err != nil {
			c.fallbackToDisabled(ctx)
		}
	}()

	if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
		return fmt.Errorf("cannot enable firewall: %w", err)
	}
	touched = true

	if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
		return fmt.Errorf("cannot enable firewall: %w", err)
	}

	const remove = false

	// Loopback traffic
	if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
		return fmt.Errorf("cannot enable firewall: %w", err)
	}
	if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
		return fmt.Errorf("cannot enable firewall: %w", err)
	}

	if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
		return fmt.Errorf("cannot enable firewall: %w", err)
	}

	if c.vpnConnection.IP != nil {
		if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
			return fmt.Errorf("cannot enable firewall: %w", err)
		}
	}

	if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
		return fmt.Errorf("cannot enable firewall: %w", err)
	}

	// Assign to the named result instead of shadowing it with :=, so the
	// deferred rollback always observes the same err variable.
	for _, network := range c.localNetworks {
		if err = c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {
			return fmt.Errorf("cannot enable firewall: %w", err)
		}
	}

	for _, subnet := range c.outboundSubnets {
		if err = c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
			return fmt.Errorf("cannot enable firewall: %w", err)
		}
	}

	// Allows packets from any IP address to go through eth0 / local network
	// to reach Gluetun.
	for _, network := range c.localNetworks {
		if err = c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil {
			return fmt.Errorf("cannot enable firewall: %w", err)
		}
	}

	for port, intf := range c.allowedInputPorts {
		if err = c.acceptInputToPort(ctx, intf, port, remove); err != nil {
			return fmt.Errorf("cannot enable firewall: %w", err)
		}
	}

	if err = c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
		return fmt.Errorf("%w: %s", ErrUserPostRules, err)
	}

	return nil
}