From 4a7d341c573b6af9bf6620c209af176d71ac9833 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Jul 2020 02:17:49 +0000 Subject: [PATCH] Fixing extra subnets firewall rules - Fix #194 - Fix #190 - Refers to #188 --- internal/firewall/enable.go | 8 ++++---- internal/firewall/iptables.go | 10 ++++------ internal/firewall/subnets.go | 22 ++++++++++++++-------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index ee65ad70..21b34350 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -102,17 +102,17 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptInputFromToSubnet(ctx, localSubnet, "*", remove); err != nil { + if err := c.acceptInputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptOutputFromToSubnet(ctx, localSubnet, "*", remove); err != nil { + if err := c.acceptOutputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } for _, subnet := range c.allowedSubnets { - if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index a1be52fc..85b56553 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -112,26 +112,24 @@ func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInte appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)) } -func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error { - subnetStr := subnet.String() +func (c *configurator) acceptInputFromSubnetToSubnet(ctx context.Context, intf string, sourceSubnet, destinationSubnet net.IPNet, remove bool) error { interfaceFlag := "-i " + intf if intf == "*" { // all interfaces interfaceFlag = "" } return c.runIptablesInstruction(ctx, fmt.Sprintf( - "%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr, + "%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, sourceSubnet.String(), destinationSubnet.String(), )) } // Thanks to @npawelek -func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error { - subnetStr := subnet.String() +func (c *configurator) acceptOutputFromSubnetToSubnet(ctx context.Context, intf string, sourceSubnet, destinationSubnet net.IPNet, remove bool) error { interfaceFlag := "-o " + intf if intf == "*" { // all interfaces interfaceFlag = "" } return c.runIptablesInstruction(ctx, fmt.Sprintf( - "%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr, + "%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, sourceSubnet.String(), destinationSubnet.String(), )) } diff --git a/internal/firewall/subnets.go b/internal/firewall/subnets.go index fbb58523..9fa25222 100644 --- a/internal/firewall/subnets.go +++ b/internal/firewall/subnets.go @@ -32,9 +32,13 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe if err != nil { return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) } + localSubnet, err := c.routing.LocalSubnet() + if err != nil { + return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) + } - c.removeSubnets(ctx, subnetsToRemove, defaultInterface) - if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil { + c.removeSubnets(ctx, subnetsToRemove, defaultInterface, localSubnet) + if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway, localSubnet); err != nil { return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) } @@ -89,15 +93,16 @@ func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet return subnets } -func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) { +func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, + localSubnet net.IPNet) { const remove = true for _, subnet := range subnets { failed := false - if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil { failed = true c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err) } - if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil { failed = true c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err) } @@ -112,13 +117,14 @@ func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, d } } -func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error { +func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, + defaultGateway net.IP, localSubnet net.IPNet) error { const remove = false for _, subnet := range subnets { - if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil { return fmt.Errorf("cannot add allowed subnet through firewall: %w", err) } - if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil { return fmt.Errorf("cannot add allowed subnet through firewall: %w", err) } if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {