Firewall refactoring

- Ability to enable and disable rules in various loops
- Simplified code overall
- Port forwarding moved into openvpn loop
- Route addition and removal improved
This commit is contained in:
Quentin McGaw
2020-07-11 21:03:55 +00:00
parent ccf11990f1
commit b1596bc7e4
20 changed files with 887 additions and 359 deletions

149
internal/firewall/enable.go Normal file
View File

@@ -0,0 +1,149 @@
package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if enabled == c.enabled {
if enabled {
c.logger.Info("already enabled")
} else {
c.logger.Info("already disabled")
}
return nil
}
if !enabled {
c.logger.Info("disabling...")
if err = c.disable(ctx); err != nil {
return err
}
c.enabled = false
c.logger.Info("disabled successfully")
return nil
}
c.logger.Info("enabling...")
if err := c.enable(ctx); err != nil {
return err
}
c.enabled = true
c.logger.Info("enabled successfully")
return nil
}
func (c *configurator) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err)
}
if err = c.setAllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err)
}
// TODO routes?
return nil
}
// To use in defered call when enabling the firewall
func (c *configurator) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.SetEnabled(ctx, false); err != nil {
c.logger.Error(err)
}
}
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute()
if err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
fmt.Println(1)
if err = c.setAllPolicies(ctx, "DROP"); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
const remove = false
defer func() {
if err != nil {
c.fallbackToDisabled(ctx)
}
}()
// Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
for _, conn := range c.vpnConnections {
if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
for _, subnet := range c.allowedSubnets {
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
// Re-ensure all routes exist
for _, subnet := range c.allowedSubnets {
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
for port := range c.allowedPorts {
// TODO restrict interface
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if c.portForwarded > 0 {
const tun = string(constants.TUN)
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
return nil
}

View File

@@ -3,42 +3,49 @@ package firewall
import (
"context"
"net"
"sync"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/routing"
)
// Configurator allows to change firewall rules and modify network routes
type Configurator interface {
Version(ctx context.Context) (string, error)
AcceptAll(ctx context.Context) error
Clear(ctx context.Context) error
BlockAll(ctx context.Context) error
CreateGeneralRules(ctx context.Context) error
CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error
AllowAnyIncomingOnPort(ctx context.Context, port uint16) error
RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error
Disable()
SetEnabled(ctx context.Context, enabled bool) (err error)
SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error)
SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error)
SetAllowedPort(ctx context.Context, port uint16) error
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
SetPortForward(ctx context.Context, port uint16) (err error)
}
type configurator struct {
commander command.Commander
logger logging.Logger
disabled bool
type configurator struct { //nolint:maligned
commander command.Commander
logger logging.Logger
routing routing.Routing
fileManager files.FileManager // for custom iptables rules
iptablesMutex sync.Mutex
// State
enabled bool
vpnConnections []models.OpenVPNConnection
allowedSubnets []net.IPNet
allowedPorts map[uint16]struct{}
portForwarded uint16
stateMutex sync.Mutex
}
// NewConfigurator creates a new Configurator instance
func NewConfigurator(logger logging.Logger) Configurator {
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
return &configurator{
commander: command.NewCommander(),
logger: logger.WithPrefix("firewall configurator: "),
commander: command.NewCommander(),
logger: logger.WithPrefix("firewall: "),
routing: routing,
fileManager: fileManager,
allowedPorts: make(map[uint16]struct{}),
}
}
func (c *configurator) Disable() {
c.disabled = true
}

View File

@@ -6,10 +6,32 @@ import (
"net"
"strings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func appendOrDelete(remove bool) string {
if remove {
return "--delete"
}
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
switch {
case strings.HasPrefix(rule, "-A"):
return strings.Replace(rule, "-A", "-D", 1)
case strings.HasPrefix(rule, "--append"):
return strings.Replace(rule, "--append", "-D", 1)
case strings.HasPrefix(rule, "-D"):
return strings.Replace(rule, "-D", "-A", 1)
case strings.HasPrefix(rule, "--delete"):
return strings.Replace(rule, "--delete", "-A", 1)
}
return rule
}
// Version obtains the version of the installed iptables
func (c *configurator) Version(ctx context.Context) (string, error) {
output, err := c.commander.Run(ctx, "iptables", "--version")
@@ -33,6 +55,8 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions
}
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
flags := strings.Fields(instruction)
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err)
@@ -40,146 +64,119 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s
return nil
}
func (c *configurator) Clear(ctx context.Context) error {
if c.disabled {
return nil
}
c.logger.Info("clearing all rules")
func (c *configurator) clearAllRules(ctx context.Context) error {
return c.runIptablesInstructions(ctx, []string{
"--flush",
"--delete-chain",
"--flush", // flush all chains
"--delete-chain", // delete all chains
})
}
func (c *configurator) AcceptAll(ctx context.Context) error {
if c.disabled {
return nil
func (c *configurator) setAllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("policy %q not recognized", policy)
}
c.logger.Info("accepting all traffic")
return c.runIptablesInstructions(ctx, []string{
"-P INPUT ACCEPT",
"-P OUTPUT ACCEPT",
"-P FORWARD ACCEPT",
fmt.Sprintf("--policy INPUT %s", policy),
fmt.Sprintf("--policy OUTPUT %s", policy),
fmt.Sprintf("--policy FORWARD %s", policy),
})
}
func (c *configurator) BlockAll(ctx context.Context) error {
if c.disabled {
return nil
}
c.logger.Info("blocking all traffic")
func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
return c.runIptablesInstructions(ctx, []string{
"-P INPUT DROP",
"-F OUTPUT",
"-P OUTPUT DROP",
"-P FORWARD DROP",
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
})
}
func (c *configurator) CreateGeneralRules(ctx context.Context) error {
if c.disabled {
return nil
}
c.logger.Info("creating general rules")
return c.runIptablesInstructions(ctx, []string{
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A OUTPUT -o lo -j ACCEPT",
"-A INPUT -i lo -j ACCEPT",
})
func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error {
return c.runIptablesInstruction(ctx,
fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
}
func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
if c.disabled {
return nil
}
for _, connection := range connections {
c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d",
connection.IP, defaultInterface, connection.Protocol, connection.Port)
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil {
return err
}
}
if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
return err
}
return nil
}
func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
if c.disabled {
return nil
}
func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
subnetStr := subnet.String()
c.logger.Info("accepting input and output traffic for %s", subnetStr)
if err := c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
}); err != nil {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
))
}
// Thanks to @npawelek
func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
subnetStr := subnet.String()
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
))
}
// Used for port forwarding, with intf set to tun
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runIptablesInstruction(ctx,
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port),
)
}
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
exists, err := c.fileManager.FileExists(filepath)
if err != nil {
return err
}
for _, extraSubnet := range extraSubnets {
extraSubnetStr := extraSubnet.String()
c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr)
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
return err
}
// Thanks to @npawelek
c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr)
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
return err
}
}
return nil
}
// Used for port forwarding
func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error {
if c.disabled {
} else if !exists {
return nil
}
c.logger.Info("accepting input traffic through %s on port %d", device, port)
return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
})
}
func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error {
if c.disabled {
return nil
}
c.logger.Info("accepting any input traffic on port %d", port)
return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port),
fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port),
})
}
func (c *configurator) RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error {
exists, err := fileManager.FileExists(filepath)
b, err := c.fileManager.ReadFile(filepath)
if err != nil {
return err
}
if exists {
b, err := fileManager.ReadFile(filepath)
if err != nil {
return err
lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
lines := strings.Split(string(b), "\n")
var rules []string
for _, line := range lines {
if !strings.HasPrefix(line, "iptables ") {
continue
}
rules = append(rules, strings.TrimPrefix(line, "iptables "))
c.logger.Info("running user post firewall rule: %s", line)
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
return c.runIptablesInstructions(ctx, rules)
}()
for _, line := range lines {
if !strings.HasPrefix(line, "iptables ") {
continue
}
rule := strings.TrimPrefix(line, "iptables ")
if remove {
rule = flipRule(rule)
}
if err = c.runIptablesInstruction(ctx, rule); err != nil {
return fmt.Errorf("cannot run custom rule: %w", err)
}
successfulRules = append(successfulRules, rule)
}
return nil
}

109
internal/firewall/ports.go Normal file
View File

@@ -0,0 +1,109 @@
package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == 0 {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal list")
c.allowedPorts[port] = struct{}{}
return nil
}
c.logger.Info("setting allowed port %d through firewall...", port)
if _, ok := c.allowedPorts[port]; ok {
return nil
}
const remove = false
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
}
c.allowedPorts[port] = struct{}{}
return nil
}
func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == 0 {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal list")
delete(c.allowedPorts, port)
return nil
}
c.logger.Info("removing allowed port %d through firewall...", port)
if _, ok := c.allowedPorts[port]; !ok {
return nil
}
const remove = true
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
}
delete(c.allowedPorts, port)
return nil
}
// Use 0 to remove
func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == c.portForwarded {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating port forwarded internally")
c.portForwarded = port
return nil
}
const tun = string(constants.TUN)
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil {
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil {
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
}
if port == 0 { // not changing port
c.portForwarded = 0
return nil
}
if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil {
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil {
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
}
return nil
}

View File

@@ -0,0 +1,127 @@
package firewall
import (
"context"
"fmt"
"net"
)
func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
c.allowedSubnets = make([]net.IPNet, len(subnets))
copy(c.allowedSubnets, subnets)
return nil
}
c.logger.Info("setting allowed subnets through firewall...")
subnetsToAdd := findSubnetsToAdd(c.allowedSubnets, subnets)
subnetsToRemove := findSubnetsToRemove(c.allowedSubnets, subnets)
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
return nil
}
defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute()
if err != nil {
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
}
c.removeSubnets(ctx, subnetsToRemove, defaultInterface)
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil {
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
}
return nil
}
func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) {
for _, newSubnet := range newSubnets {
found := false
for _, oldSubnet := range oldSubnets {
if subnetsAreEqual(oldSubnet, newSubnet) {
found = true
break
}
}
if !found {
subnetsToAdd = append(subnetsToAdd, newSubnet)
}
}
return subnetsToAdd
}
func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) {
for _, oldSubnet := range oldSubnets {
found := false
for _, newSubnet := range newSubnets {
if subnetsAreEqual(oldSubnet, newSubnet) {
found = true
break
}
}
if !found {
subnetsToRemove = append(subnetsToRemove, oldSubnet)
}
}
return subnetsToRemove
}
func subnetsAreEqual(a, b net.IPNet) bool {
return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
}
func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet {
L := len(subnets)
for i := range subnets {
if subnetsAreEqual(subnet, subnets[i]) {
subnets[i] = subnets[L-1]
subnets = subnets[:L-1]
break
}
}
return subnets
}
func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) {
const remove = true
for _, subnet := range subnets {
failed := false
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
failed = true
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
}
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
failed = true
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
}
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
failed = true
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
}
if failed {
continue
}
c.allowedSubnets = removeSubnetFromSubnets(c.allowedSubnets, subnet)
}
}
func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error {
const remove = false
for _, subnet := range subnets {
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
}
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
}
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
}
c.allowedSubnets = append(c.allowedSubnets, subnet)
}
return nil
}

112
internal/firewall/vpn.go Normal file
View File

@@ -0,0 +1,112 @@
package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if !c.enabled {
c.logger.Info("firewall disabled, only updating VPN connections internal list")
c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
copy(c.vpnConnections, connections)
return nil
}
c.logger.Info("setting VPN connections through firewall...")
connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections)
connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections)
if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 {
return nil
}
defaultInterface, _, _, err := c.routing.DefaultRoute()
if err != nil {
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
}
// TODO remove elsewhere?
if err := c.acceptOutputThroughInterface(ctx, string(constants.TUN), false); err != nil {
return fmt.Errorf("cannot allow traffic through tunnel: %w", err)
}
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
}
return nil
}
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
L := len(connections)
for i := range connections {
if connection.Equal(connections[i]) {
connections[i] = connections[L-1]
connections = connections[:L-1]
break
}
}
return connections
}
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
for _, newConnection := range newConnections {
found := false
for _, oldConnection := range oldConnections {
if oldConnection.Equal(newConnection) {
found = true
break
}
}
if !found {
connectionsToAdd = append(connectionsToAdd, newConnection)
}
}
return connectionsToAdd
}
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
for _, oldConnection := range oldConnections {
found := false
for _, newConnection := range newConnections {
if oldConnection.Equal(newConnection) {
found = true
break
}
}
if !found {
connectionsToRemove = append(connectionsToRemove, oldConnection)
}
}
return connectionsToRemove
}
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
for _, conn := range connections {
const remove = true
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
continue
}
c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn)
}
}
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
const remove = false
for _, conn := range connections {
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
return err
}
c.vpnConnections = append(c.vpnConnections, conn)
}
return nil
}