Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
@@ -6,10 +6,32 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
if remove {
|
||||
return "--delete"
|
||||
}
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(rule, "-A"):
|
||||
return strings.Replace(rule, "-A", "-D", 1)
|
||||
case strings.HasPrefix(rule, "--append"):
|
||||
return strings.Replace(rule, "--append", "-D", 1)
|
||||
case strings.HasPrefix(rule, "-D"):
|
||||
return strings.Replace(rule, "-D", "-A", 1)
|
||||
case strings.HasPrefix(rule, "--delete"):
|
||||
return strings.Replace(rule, "--delete", "-A", 1)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables
|
||||
func (c *configurator) Version(ctx context.Context) (string, error) {
|
||||
output, err := c.commander.Run(ctx, "iptables", "--version")
|
||||
@@ -33,6 +55,8 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions
|
||||
}
|
||||
|
||||
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
flags := strings.Fields(instruction)
|
||||
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
|
||||
return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err)
|
||||
@@ -40,146 +64,119 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) Clear(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("clearing all rules")
|
||||
func (c *configurator) clearAllRules(ctx context.Context) error {
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--flush",
|
||||
"--delete-chain",
|
||||
"--flush", // flush all chains
|
||||
"--delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) AcceptAll(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
func (c *configurator) setAllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("policy %q not recognized", policy)
|
||||
}
|
||||
c.logger.Info("accepting all traffic")
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"-P INPUT ACCEPT",
|
||||
"-P OUTPUT ACCEPT",
|
||||
"-P FORWARD ACCEPT",
|
||||
fmt.Sprintf("--policy INPUT %s", policy),
|
||||
fmt.Sprintf("--policy OUTPUT %s", policy),
|
||||
fmt.Sprintf("--policy FORWARD %s", policy),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) BlockAll(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("blocking all traffic")
|
||||
func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"-P INPUT DROP",
|
||||
"-F OUTPUT",
|
||||
"-P OUTPUT DROP",
|
||||
"-P FORWARD DROP",
|
||||
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) CreateGeneralRules(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("creating general rules")
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A OUTPUT -o lo -j ACCEPT",
|
||||
"-A INPUT -i lo -j ACCEPT",
|
||||
})
|
||||
func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error {
|
||||
return c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
|
||||
}
|
||||
|
||||
func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
for _, connection := range connections {
|
||||
c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d",
|
||||
connection.IP, defaultInterface, connection.Protocol, connection.Port)
|
||||
if err := c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
||||
subnetStr := subnet.String()
|
||||
c.logger.Info("accepting input and output traffic for %s", subnetStr)
|
||||
if err := c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
}); err != nil {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
||||
))
|
||||
}
|
||||
|
||||
// Thanks to @npawelek
|
||||
func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
||||
subnetStr := subnet.String()
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
||||
))
|
||||
}
|
||||
|
||||
// Used for port forwarding, with intf set to tun
|
||||
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
exists, err := c.fileManager.FileExists(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, extraSubnet := range extraSubnets {
|
||||
extraSubnetStr := extraSubnet.String()
|
||||
c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr)
|
||||
if err := c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Thanks to @npawelek
|
||||
c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr)
|
||||
if err := c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Used for port forwarding
|
||||
func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error {
|
||||
if c.disabled {
|
||||
} else if !exists {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("accepting input traffic through %s on port %d", device, port)
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
|
||||
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("accepting any input traffic on port %d", port)
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port),
|
||||
fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error {
|
||||
exists, err := fileManager.FileExists(filepath)
|
||||
b, err := c.fileManager.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
b, err := fileManager.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
lines := strings.Split(string(b), "\n")
|
||||
successfulRules := []string{}
|
||||
defer func() {
|
||||
// transaction-like rollback
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
var rules []string
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "iptables ") {
|
||||
continue
|
||||
}
|
||||
rules = append(rules, strings.TrimPrefix(line, "iptables "))
|
||||
c.logger.Info("running user post firewall rule: %s", line)
|
||||
for _, rule := range successfulRules {
|
||||
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, rules)
|
||||
}()
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "iptables ") {
|
||||
continue
|
||||
}
|
||||
rule := strings.TrimPrefix(line, "iptables ")
|
||||
if remove {
|
||||
rule = flipRule(rule)
|
||||
}
|
||||
if err = c.runIptablesInstruction(ctx, rule); err != nil {
|
||||
return fmt.Errorf("cannot run custom rule: %w", err)
|
||||
}
|
||||
successfulRules = append(successfulRules, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user