diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 768edbb2..c4f0c4ec 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -17,13 +17,10 @@ import ( "github.com/qdm12/golibs/network" "github.com/qdm12/private-internet-access-docker/internal/alpine" "github.com/qdm12/private-internet-access-docker/internal/cli" - "github.com/qdm12/private-internet-access-docker/internal/constants" "github.com/qdm12/private-internet-access-docker/internal/dns" "github.com/qdm12/private-internet-access-docker/internal/firewall" - "github.com/qdm12/private-internet-access-docker/internal/models" "github.com/qdm12/private-internet-access-docker/internal/openvpn" "github.com/qdm12/private-internet-access-docker/internal/params" - "github.com/qdm12/private-internet-access-docker/internal/provider" "github.com/qdm12/private-internet-access-docker/internal/publicip" "github.com/qdm12/private-internet-access-docker/internal/routing" "github.com/qdm12/private-internet-access-docker/internal/server" @@ -72,8 +69,8 @@ func _main(background context.Context, args []string) int { alpineConf := alpine.NewConfigurator(fileManager) ovpnConf := openvpn.NewConfigurator(logger, fileManager) dnsConf := dns.NewConfigurator(logger, client, fileManager) - firewallConf := firewall.NewConfigurator(logger) routingConf := routing.NewRouting(logger, fileManager) + firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager) tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger) shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger) streamMerger := command.NewStreamMerger() @@ -93,12 +90,6 @@ func _main(background context.Context, args []string) int { // Should never change uid, gid := allSettings.System.UID, allSettings.System.GID - providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf) - - if !allSettings.Firewall.Enabled { - firewallConf.Disable() - } - err = alpineConf.CreateUser("nonrootuser", uid) fatalOnError(err) err = fileManager.SetOwnership("/etc/unbound", uid, gid) @@ -112,17 +103,6 @@ func _main(background context.Context, args []string) int { fatalOnError(err) } - defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute() - fatalOnError(err) - - // Temporarily reset chain policies allowing Kubernetes sidecar to - // successfully restart the container. Without this, the existing rules will - // pre-exist, preventing the nslookup of the PIA region address. These will - // simply be redundant at Docker runtime as they will already be set this way - // Thanks to @npawelek https://github.com/npawelek - err = firewallConf.AcceptAll(ctx) - fatalOnError(err) - connectedCh := make(chan struct{}) signalConnected := func() { connectedCh <- struct{}{} @@ -130,44 +110,23 @@ func _main(background context.Context, args []string) int { defer close(connectedCh) go collectStreamLines(ctx, streamMerger, logger, signalConnected) - connections, err := providerConf.GetOpenVPNConnections(allSettings.OpenVPN.Provider.ServerSelection) - fatalOnError(err) - err = providerConf.BuildConf( - connections, - allSettings.OpenVPN.Verbosity, - uid, - gid, - allSettings.OpenVPN.Root, - allSettings.OpenVPN.Cipher, - allSettings.OpenVPN.Auth, - allSettings.OpenVPN.Provider.ExtraConfigOptions, - ) - fatalOnError(err) - - err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface) - fatalOnError(err) - err = firewallConf.Clear(ctx) - fatalOnError(err) - err = firewallConf.BlockAll(ctx) - fatalOnError(err) - err = firewallConf.CreateGeneralRules(ctx) - fatalOnError(err) - err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections) - fatalOnError(err) - err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface) - fatalOnError(err) - err = firewallConf.RunUserPostRules(ctx, fileManager, "/iptables/post-rules.txt") - fatalOnError(err) - + // TODO replace these with methods on loopers and pass loopers around restartOpenvpn := make(chan struct{}) + portForward := make(chan struct{}) restartUnbound := make(chan struct{}) restartPublicIP := make(chan struct{}) restartTinyproxy := make(chan struct{}) restartShadowsocks := make(chan struct{}) - openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError, uid, gid) + if allSettings.Firewall.Enabled { + err := firewallConf.SetEnabled(ctx, true) // disabled by default + fatalOnError(err) + } + + openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, + ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError) // wait for restartOpenvpn - go openvpnLooper.Run(ctx, restartOpenvpn, wg) + go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) // wait for restartUnbound @@ -191,7 +150,6 @@ func _main(background context.Context, args []string) int { } go func() { - first := true var restartTickerContext context.Context var restartTickerCancel context.CancelFunc = func() {} for { @@ -200,14 +158,10 @@ func _main(background context.Context, args []string) int { restartTickerCancel() return case <-connectedCh: // blocks until openvpn is connected - if first { - first = false - restartUnbound <- struct{}{} - } restartTickerCancel() restartTickerContext, restartTickerCancel = context.WithCancel(ctx) go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound) - onConnected(allSettings, logger, routingConf, defaultInterface, providerConf, restartPublicIP) + onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP) } } }() @@ -224,11 +178,10 @@ func _main(background context.Context, args []string) int { syscall.SIGTERM, os.Interrupt, ) - exitStatus := 0 + shutdownErrorsCount := 0 select { case signal := <-signalsCh: logger.Warn("Caught OS signal %s, shutting down", signal) - exitStatus = 1 cancel() case <-ctx.Done(): logger.Warn("context canceled, shutting down") @@ -236,20 +189,37 @@ func _main(background context.Context, args []string) int { logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath) if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil { logger.Error(err) - exitStatus = 1 + shutdownErrorsCount++ } if allSettings.OpenVPN.Provider.PortForwarding.Enabled { logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { logger.Error(err) - exitStatus = 1 + shutdownErrorsCount++ } } - wg.Wait() - return exitStatus + waiting, waited := context.WithTimeout(context.Background(), time.Second) + go func() { + defer waited() + wg.Wait() + }() + <-waiting.Done() + if waiting.Err() == context.DeadlineExceeded { + if shutdownErrorsCount > 0 { + logger.Warn("Shutdown had %d errors", shutdownErrorsCount) + } + logger.Warn("Shutdown timed out") + return 1 + } + if shutdownErrorsCount > 0 { + logger.Warn("Shutdown had %d errors") + return 1 + } + logger.Info("Shutdown successful") + return 0 } -func makeFatalOnError(logger logging.Logger, cancel func(), wg *sync.WaitGroup) func(err error) { +func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc, wg *sync.WaitGroup) func(err error) { return func(err error) { if err != nil { logger.Error(err) @@ -321,48 +291,25 @@ func trimEventualProgramPrefix(s string) string { } } -func onConnected(allSettings settings.Settings, - logger logging.Logger, routingConf routing.Routing, defaultInterface string, - providerConf provider.Provider, restartPublicIP chan<- struct{}, +func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing, + portForward, restartUnbound, restartPublicIP chan<- struct{}, ) { + restartUnbound <- struct{}{} restartPublicIP <- struct{}{} - uid, gid := allSettings.System.UID, allSettings.System.GID if allSettings.OpenVPN.Provider.PortForwarding.Enabled { time.AfterFunc(5*time.Second, func() { - setupPortForwarding(logger, providerConf, allSettings.OpenVPN.Provider.PortForwarding.Filepath, uid, gid) + portForward <- struct{}{} }) } - - vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface) + defaultInterface, _, _, err := routingConf.DefaultRoute() if err != nil { logger.Warn(err) } else { - logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) - } -} - -func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) { - pfLogger := logger.WithPrefix("port forwarding: ") - var port uint16 - var err error - for { - port, err = providerConf.GetPortForward() + vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface) if err != nil { - pfLogger.Error(err) - pfLogger.Info("retrying in 5 seconds...") - time.Sleep(5 * time.Second) + logger.Warn(err) } else { - pfLogger.Info("port forwarded is %d", port) - break + logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) } } - pfLogger.Info("writing forwarded port to %s", filepath) - if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil { - pfLogger.Error(err) - } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { - pfLogger.Error(err) - } } diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go new file mode 100644 index 00000000..8a201246 --- /dev/null +++ b/internal/firewall/enable.go @@ -0,0 +1,149 @@ +package firewall + +import ( + "context" + "fmt" + + "github.com/qdm12/private-internet-access-docker/internal/constants" +) + +func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if enabled == c.enabled { + if enabled { + c.logger.Info("already enabled") + } else { + c.logger.Info("already disabled") + } + return nil + } + + if !enabled { + c.logger.Info("disabling...") + if err = c.disable(ctx); err != nil { + return err + } + c.enabled = false + c.logger.Info("disabled successfully") + return nil + } + + c.logger.Info("enabling...") + + if err := c.enable(ctx); err != nil { + return err + } + c.enabled = true + c.logger.Info("enabled successfully") + + return nil +} + +func (c *configurator) disable(ctx context.Context) (err error) { + if err = c.clearAllRules(ctx); err != nil { + return fmt.Errorf("cannot disable firewall: %w", err) + } + if err = c.setAllPolicies(ctx, "ACCEPT"); err != nil { + return fmt.Errorf("cannot disable firewall: %w", err) + } + // TODO routes? + return nil +} + +// To use in defered call when enabling the firewall +func (c *configurator) fallbackToDisabled(ctx context.Context) { + if ctx.Err() != nil { + return + } + if err := c.SetEnabled(ctx, false); err != nil { + c.logger.Error(err) + } +} + +func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit + defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute() + if err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + + fmt.Println(1) + if err = c.setAllPolicies(ctx, "DROP"); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + + const remove = false + + defer func() { + if err != nil { + c.fallbackToDisabled(ctx) + } + }() + + // Loopback traffic + if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + + if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + for _, conn := range c.vpnConnections { + if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + } + if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + for _, subnet := range c.allowedSubnets { + if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + } + // Re-ensure all routes exist + for _, subnet := range c.allowedSubnets { + if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + } + + for port := range c.allowedPorts { + // TODO restrict interface + if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + } + + if c.portForwarded > 0 { + const tun = string(constants.TUN) + if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + } + + if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + + return nil +} diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 326d52f2..805ad8c2 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -3,42 +3,49 @@ package firewall import ( "context" "net" + "sync" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/private-internet-access-docker/internal/models" + "github.com/qdm12/private-internet-access-docker/internal/routing" ) // Configurator allows to change firewall rules and modify network routes type Configurator interface { Version(ctx context.Context) (string, error) - AcceptAll(ctx context.Context) error - Clear(ctx context.Context) error - BlockAll(ctx context.Context) error - CreateGeneralRules(ctx context.Context) error - CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error - CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error - AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error - AllowAnyIncomingOnPort(ctx context.Context, port uint16) error - RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error - Disable() + SetEnabled(ctx context.Context, enabled bool) (err error) + SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) + SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) + SetAllowedPort(ctx context.Context, port uint16) error + RemoveAllowedPort(ctx context.Context, port uint16) (err error) + SetPortForward(ctx context.Context, port uint16) (err error) } -type configurator struct { - commander command.Commander - logger logging.Logger - disabled bool +type configurator struct { //nolint:maligned + commander command.Commander + logger logging.Logger + routing routing.Routing + fileManager files.FileManager // for custom iptables rules + iptablesMutex sync.Mutex + + // State + enabled bool + vpnConnections []models.OpenVPNConnection + allowedSubnets []net.IPNet + allowedPorts map[uint16]struct{} + portForwarded uint16 + stateMutex sync.Mutex } // NewConfigurator creates a new Configurator instance -func NewConfigurator(logger logging.Logger) Configurator { +func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator { return &configurator{ - commander: command.NewCommander(), - logger: logger.WithPrefix("firewall configurator: "), + commander: command.NewCommander(), + logger: logger.WithPrefix("firewall: "), + routing: routing, + fileManager: fileManager, + allowedPorts: make(map[uint16]struct{}), } } - -func (c *configurator) Disable() { - c.disabled = true -} diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index f9bb80a3..e26d6b0b 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -6,10 +6,32 @@ import ( "net" "strings" - "github.com/qdm12/golibs/files" "github.com/qdm12/private-internet-access-docker/internal/models" ) +func appendOrDelete(remove bool) string { + if remove { + return "--delete" + } + return "--append" +} + +// flipRule changes an append rule in a delete rule or a delete rule into an +// append rule. +func flipRule(rule string) string { + switch { + case strings.HasPrefix(rule, "-A"): + return strings.Replace(rule, "-A", "-D", 1) + case strings.HasPrefix(rule, "--append"): + return strings.Replace(rule, "--append", "-D", 1) + case strings.HasPrefix(rule, "-D"): + return strings.Replace(rule, "-D", "-A", 1) + case strings.HasPrefix(rule, "--delete"): + return strings.Replace(rule, "--delete", "-A", 1) + } + return rule +} + // Version obtains the version of the installed iptables func (c *configurator) Version(ctx context.Context) (string, error) { output, err := c.commander.Run(ctx, "iptables", "--version") @@ -33,6 +55,8 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions } func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error { + c.iptablesMutex.Lock() // only one iptables command at once + defer c.iptablesMutex.Unlock() flags := strings.Fields(instruction) if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil { return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err) @@ -40,146 +64,119 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s return nil } -func (c *configurator) Clear(ctx context.Context) error { - if c.disabled { - return nil - } - c.logger.Info("clearing all rules") +func (c *configurator) clearAllRules(ctx context.Context) error { return c.runIptablesInstructions(ctx, []string{ - "--flush", - "--delete-chain", + "--flush", // flush all chains + "--delete-chain", // delete all chains }) } -func (c *configurator) AcceptAll(ctx context.Context) error { - if c.disabled { - return nil +func (c *configurator) setAllPolicies(ctx context.Context, policy string) error { + switch policy { + case "ACCEPT", "DROP": + default: + return fmt.Errorf("policy %q not recognized", policy) } - c.logger.Info("accepting all traffic") return c.runIptablesInstructions(ctx, []string{ - "-P INPUT ACCEPT", - "-P OUTPUT ACCEPT", - "-P FORWARD ACCEPT", + fmt.Sprintf("--policy INPUT %s", policy), + fmt.Sprintf("--policy OUTPUT %s", policy), + fmt.Sprintf("--policy FORWARD %s", policy), }) } -func (c *configurator) BlockAll(ctx context.Context) error { - if c.disabled { - return nil - } - c.logger.Info("blocking all traffic") +func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { + return c.runIptablesInstruction(ctx, fmt.Sprintf( + "%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf, + )) +} + +func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error { + return c.runIptablesInstruction(ctx, fmt.Sprintf( + "%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf, + )) +} + +func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { return c.runIptablesInstructions(ctx, []string{ - "-P INPUT DROP", - "-F OUTPUT", - "-P OUTPUT DROP", - "-P FORWARD DROP", + fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), + fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), }) } -func (c *configurator) CreateGeneralRules(ctx context.Context) error { - if c.disabled { - return nil - } - c.logger.Info("creating general rules") - return c.runIptablesInstructions(ctx, []string{ - "-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", - "-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", - "-A OUTPUT -o lo -j ACCEPT", - "-A INPUT -i lo -j ACCEPT", - }) +func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error { + return c.runIptablesInstruction(ctx, + fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", + appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)) } -func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error { - if c.disabled { - return nil - } - for _, connection := range connections { - c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d", - connection.IP, defaultInterface, connection.Protocol, connection.Port) - if err := c.runIptablesInstruction(ctx, - fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", - connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil { - return err - } - } - if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil { - return err - } - return nil -} - -func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error { - if c.disabled { - return nil - } +func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error { subnetStr := subnet.String() - c.logger.Info("accepting input and output traffic for %s", subnetStr) - if err := c.runIptablesInstructions(ctx, []string{ - fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr), - fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr), - }); err != nil { + interfaceFlag := "-i " + intf + if intf == "*" { // all interfaces + interfaceFlag = "" + } + return c.runIptablesInstruction(ctx, fmt.Sprintf( + "%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr, + )) +} + +// Thanks to @npawelek +func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error { + subnetStr := subnet.String() + interfaceFlag := "-o " + intf + if intf == "*" { // all interfaces + interfaceFlag = "" + } + return c.runIptablesInstruction(ctx, fmt.Sprintf( + "%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr, + )) +} + +// Used for port forwarding, with intf set to tun +func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error { + interfaceFlag := "-i " + intf + if intf == "*" { // all interfaces + interfaceFlag = "" + } + return c.runIptablesInstruction(ctx, + fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port), + ) +} + +func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { + exists, err := c.fileManager.FileExists(filepath) + if err != nil { return err - } - for _, extraSubnet := range extraSubnets { - extraSubnetStr := extraSubnet.String() - c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr) - if err := c.runIptablesInstruction(ctx, - fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil { - return err - } - // Thanks to @npawelek - c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr) - if err := c.runIptablesInstruction(ctx, - fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil { - return err - } - } - return nil -} - -// Used for port forwarding -func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error { - if c.disabled { + } else if !exists { return nil } - c.logger.Info("accepting input traffic through %s on port %d", device, port) - return c.runIptablesInstructions(ctx, []string{ - fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port), - fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port), - }) -} - -func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error { - if c.disabled { - return nil - } - c.logger.Info("accepting any input traffic on port %d", port) - return c.runIptablesInstructions(ctx, []string{ - fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port), - fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port), - }) -} - -func (c *configurator) RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error { - exists, err := fileManager.FileExists(filepath) + b, err := c.fileManager.ReadFile(filepath) if err != nil { return err } - if exists { - b, err := fileManager.ReadFile(filepath) - if err != nil { - return err + lines := strings.Split(string(b), "\n") + successfulRules := []string{} + defer func() { + // transaction-like rollback + if err == nil || ctx.Err() != nil { + return } - lines := strings.Split(string(b), "\n") - var rules []string - for _, line := range lines { - if !strings.HasPrefix(line, "iptables ") { - continue - } - rules = append(rules, strings.TrimPrefix(line, "iptables ")) - c.logger.Info("running user post firewall rule: %s", line) + for _, rule := range successfulRules { + _ = c.runIptablesInstruction(ctx, flipRule(rule)) } - return c.runIptablesInstructions(ctx, rules) + }() + for _, line := range lines { + if !strings.HasPrefix(line, "iptables ") { + continue + } + rule := strings.TrimPrefix(line, "iptables ") + if remove { + rule = flipRule(rule) + } + if err = c.runIptablesInstruction(ctx, rule); err != nil { + return fmt.Errorf("cannot run custom rule: %w", err) + } + successfulRules = append(successfulRules, rule) } return nil } diff --git a/internal/firewall/ports.go b/internal/firewall/ports.go new file mode 100644 index 00000000..cd92c8ce --- /dev/null +++ b/internal/firewall/ports.go @@ -0,0 +1,109 @@ +package firewall + +import ( + "context" + "fmt" + + "github.com/qdm12/private-internet-access-docker/internal/constants" +) + +func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if port == 0 { + return nil + } + + if !c.enabled { + c.logger.Info("firewall disabled, only updating allowed ports internal list") + c.allowedPorts[port] = struct{}{} + return nil + } + + c.logger.Info("setting allowed port %d through firewall...", port) + + if _, ok := c.allowedPorts[port]; ok { + return nil + } + + const remove = false + if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { + return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err) + } + if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { + return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err) + } + c.allowedPorts[port] = struct{}{} + + return nil +} + +func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if port == 0 { + return nil + } + + if !c.enabled { + c.logger.Info("firewall disabled, only updating allowed ports internal list") + delete(c.allowedPorts, port) + return nil + } + + c.logger.Info("removing allowed port %d through firewall...", port) + + if _, ok := c.allowedPorts[port]; !ok { + return nil + } + + const remove = true + if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { + return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err) + } + if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { + return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err) + } + delete(c.allowedPorts, port) + + return nil +} + +// Use 0 to remove +func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if port == c.portForwarded { + return nil + } + + if !c.enabled { + c.logger.Info("firewall disabled, only updating port forwarded internally") + c.portForwarded = port + return nil + } + + const tun = string(constants.TUN) + if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil { + return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err) + } + if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil { + return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err) + } + + if port == 0 { // not changing port + c.portForwarded = 0 + return nil + } + + if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil { + return fmt.Errorf("cannot accept port forwarded through firewall: %w", err) + } + if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil { + return fmt.Errorf("cannot accept port forwarded through firewall: %w", err) + } + return nil +} diff --git a/internal/firewall/subnets.go b/internal/firewall/subnets.go new file mode 100644 index 00000000..5160c5ac --- /dev/null +++ b/internal/firewall/subnets.go @@ -0,0 +1,127 @@ +package firewall + +import ( + "context" + "fmt" + "net" +) + +func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if !c.enabled { + c.logger.Info("firewall disabled, only updating allowed subnets internal list") + c.allowedSubnets = make([]net.IPNet, len(subnets)) + copy(c.allowedSubnets, subnets) + return nil + } + + c.logger.Info("setting allowed subnets through firewall...") + + subnetsToAdd := findSubnetsToAdd(c.allowedSubnets, subnets) + subnetsToRemove := findSubnetsToRemove(c.allowedSubnets, subnets) + if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 { + return nil + } + + defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute() + if err != nil { + return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) + } + + c.removeSubnets(ctx, subnetsToRemove, defaultInterface) + if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil { + return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) + } + + return nil +} + +func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) { + for _, newSubnet := range newSubnets { + found := false + for _, oldSubnet := range oldSubnets { + if subnetsAreEqual(oldSubnet, newSubnet) { + found = true + break + } + } + if !found { + subnetsToAdd = append(subnetsToAdd, newSubnet) + } + } + return subnetsToAdd +} + +func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) { + for _, oldSubnet := range oldSubnets { + found := false + for _, newSubnet := range newSubnets { + if subnetsAreEqual(oldSubnet, newSubnet) { + found = true + break + } + } + if !found { + subnetsToRemove = append(subnetsToRemove, oldSubnet) + } + } + return subnetsToRemove +} + +func subnetsAreEqual(a, b net.IPNet) bool { + return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String() +} + +func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet { + L := len(subnets) + for i := range subnets { + if subnetsAreEqual(subnet, subnets[i]) { + subnets[i] = subnets[L-1] + subnets = subnets[:L-1] + break + } + } + return subnets +} + +func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) { + const remove = true + for _, subnet := range subnets { + failed := false + if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + failed = true + c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err) + } + if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + failed = true + c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err) + } + if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil { + failed = true + c.logger.Error("cannot remove outdated allowed subnet route: %s", err) + } + if failed { + continue + } + c.allowedSubnets = removeSubnetFromSubnets(c.allowedSubnets, subnet) + } +} + +func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error { + const remove = false + for _, subnet := range subnets { + if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + return fmt.Errorf("cannot add allowed subnet through firewall: %w", err) + } + if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil { + return fmt.Errorf("cannot add allowed subnet through firewall: %w", err) + } + if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil { + return fmt.Errorf("cannot add route for allowed subnet: %w", err) + } + c.allowedSubnets = append(c.allowedSubnets, subnet) + } + return nil +} diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go new file mode 100644 index 00000000..1428c153 --- /dev/null +++ b/internal/firewall/vpn.go @@ -0,0 +1,112 @@ +package firewall + +import ( + "context" + "fmt" + + "github.com/qdm12/private-internet-access-docker/internal/constants" + "github.com/qdm12/private-internet-access-docker/internal/models" +) + +func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if !c.enabled { + c.logger.Info("firewall disabled, only updating VPN connections internal list") + c.vpnConnections = make([]models.OpenVPNConnection, len(connections)) + copy(c.vpnConnections, connections) + return nil + } + + c.logger.Info("setting VPN connections through firewall...") + + connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections) + connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections) + if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 { + return nil + } + + defaultInterface, _, _, err := c.routing.DefaultRoute() + if err != nil { + return fmt.Errorf("cannot set VPN connections through firewall: %w", err) + } + + // TODO remove elsewhere? + if err := c.acceptOutputThroughInterface(ctx, string(constants.TUN), false); err != nil { + return fmt.Errorf("cannot allow traffic through tunnel: %w", err) + } + + c.removeConnections(ctx, connectionsToRemove, defaultInterface) + if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil { + return fmt.Errorf("cannot set VPN connections through firewall: %w", err) + } + + return nil +} + +func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection { + L := len(connections) + for i := range connections { + if connection.Equal(connections[i]) { + connections[i] = connections[L-1] + connections = connections[:L-1] + break + } + } + return connections +} + +func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) { + for _, newConnection := range newConnections { + found := false + for _, oldConnection := range oldConnections { + if oldConnection.Equal(newConnection) { + found = true + break + } + } + if !found { + connectionsToAdd = append(connectionsToAdd, newConnection) + } + } + return connectionsToAdd +} + +func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) { + for _, oldConnection := range oldConnections { + found := false + for _, newConnection := range newConnections { + if oldConnection.Equal(newConnection) { + found = true + break + } + } + if !found { + connectionsToRemove = append(connectionsToRemove, oldConnection) + } + } + return connectionsToRemove +} + +func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) { + for _, conn := range connections { + const remove = true + if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil { + c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err) + continue + } + c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn) + } +} + +func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error { + const remove = false + for _, conn := range connections { + if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil { + return err + } + c.vpnConnections = append(c.vpnConnections, conn) + } + return nil +} diff --git a/internal/models/openvpn.go b/internal/models/openvpn.go index 837f04fa..6f20cd91 100644 --- a/internal/models/openvpn.go +++ b/internal/models/openvpn.go @@ -7,3 +7,7 @@ type OpenVPNConnection struct { Port uint16 Protocol NetworkProtocol } + +func (o *OpenVPNConnection) Equal(other OpenVPNConnection) bool { + return o.IP.Equal(other.IP) && o.Port == other.Port && o.Protocol == other.Protocol +} diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 753a5134..5e2db08e 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -2,43 +2,64 @@ package openvpn import ( "context" + "fmt" "sync" "time" "github.com/qdm12/golibs/command" + "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" + "github.com/qdm12/golibs/network" "github.com/qdm12/private-internet-access-docker/internal/constants" + "github.com/qdm12/private-internet-access-docker/internal/firewall" + "github.com/qdm12/private-internet-access-docker/internal/models" + "github.com/qdm12/private-internet-access-docker/internal/provider" "github.com/qdm12/private-internet-access-docker/internal/settings" ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) + Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) } type looper struct { - conf Configurator - settings settings.OpenVPN + // Variable parameters + provider models.VPNProvider + settings settings.OpenVPN + // Fixed parameters + uid int + gid int + // Configurators + conf Configurator + fw firewall.Configurator + // Other objects logger logging.Logger + client network.Client + fileManager files.FileManager streamMerger command.StreamMerger fatalOnError func(err error) - uid int - gid int } -func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger, - streamMerger command.StreamMerger, fatalOnError func(err error), uid, gid int) Looper { +func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, + uid, gid int, + conf Configurator, fw firewall.Configurator, + logger logging.Logger, client network.Client, fileManager files.FileManager, + streamMerger command.StreamMerger, fatalOnError func(err error)) Looper { return &looper{ - conf: conf, + provider: provider, settings: settings, - logger: logger.WithPrefix("openvpn: "), - streamMerger: streamMerger, - fatalOnError: fatalOnError, uid: uid, gid: gid, + conf: conf, + fw: fw, + logger: logger.WithPrefix("openvpn: "), + client: client, + fileManager: fileManager, + streamMerger: streamMerger, + fatalOnError: fatalOnError, } } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { +func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() select { @@ -46,17 +67,51 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-ctx.Done(): return } - for { - openvpnCtx, openvpnCancel := context.WithCancel(ctx) - err := l.conf.WriteAuthFile( - l.settings.User, - l.settings.Password, + defer l.logger.Warn("loop exited") + + for ctx.Err() == nil { + providerConf := provider.New(l.provider, l.client, l.fileManager) + connections, err := providerConf.GetOpenVPNConnections(l.settings.Provider.ServerSelection) + l.fatalOnError(err) + err = providerConf.BuildConf( + connections, + l.settings.Verbosity, l.uid, l.gid, + l.settings.Root, + l.settings.Cipher, + l.settings.Auth, + l.settings.Provider.ExtraConfigOptions, ) l.fatalOnError(err) - stream, waitFn, err := l.conf.Start(openvpnCtx) + + err = l.conf.WriteAuthFile(l.settings.User, l.settings.Password, l.uid, l.gid) l.fatalOnError(err) + + if err := l.fw.SetVPNConnections(ctx, connections); err != nil { + l.fatalOnError(err) + } + + openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) + + stream, waitFn, err := l.conf.Start(openvpnCtx) + if err != nil { + openvpnCancel() + l.logAndWait(ctx, err) + continue + } + + go func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-portForward: + l.portForward(ctx, providerConf) + } + } + }(openvpnCtx) + go l.streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn())) waitError := make(chan error) @@ -74,13 +129,53 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-restart: // triggered restart l.logger.Info("restarting") openvpnCancel() + <-waitError close(waitError) case err := <-waitError: // unexpected error - l.logger.Warn(err) - l.logger.Info("restarting") openvpnCancel() close(waitError) - time.Sleep(time.Second) + l.logAndWait(ctx, err) } } } + +func (l *looper) logAndWait(ctx context.Context, err error) { + l.logger.Error(err) + l.logger.Info("retrying in 30 seconds") + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() // just for the linter + <-ctx.Done() +} + +func (l *looper) portForward(ctx context.Context, providerConf provider.Provider) { + if !l.settings.Provider.PortForwarding.Enabled { + return + } + var port uint16 + err := fmt.Errorf("") + for err != nil { + if ctx.Err() != nil { + return + } + port, err = providerConf.GetPortForward() + if err != nil { + l.logAndWait(ctx, err) + continue + } + l.logger.Info("port forwarded is %d", port) + } + + filepath := l.settings.Provider.PortForwarding.Filepath + l.logger.Info("writing forwarded port to %s", filepath) + err = l.fileManager.WriteLinesToFile( + string(filepath), []string{fmt.Sprintf("%d", port)}, + files.Ownership(l.uid, l.gid), files.Permissions(0400), + ) + if err != nil { + l.logger.Error(err) + } + + if err := l.fw.SetPortForward(ctx, port); err != nil { + l.logger.Error(err) + } +} diff --git a/internal/provider/cyberghost.go b/internal/provider/cyberghost.go index b9c1e2fc..fbe853ac 100644 --- a/internal/provider/cyberghost.go +++ b/internal/provider/cyberghost.go @@ -1,7 +1,6 @@ package provider import ( - "context" "fmt" "net" "strings" @@ -125,11 +124,3 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity func (c *cyberghost) GetPortForward() (port uint16, err error) { panic("port forwarding is not supported for cyberghost") } - -func (c *cyberghost) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { - panic("port forwarding is not supported for cyberghost") -} - -func (c *cyberghost) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { - panic("port forwarding is not supported for cyberghost") -} diff --git a/internal/provider/mullvad.go b/internal/provider/mullvad.go index 44d02cd5..91e63623 100644 --- a/internal/provider/mullvad.go +++ b/internal/provider/mullvad.go @@ -1,24 +1,20 @@ package provider import ( - "context" "fmt" "github.com/qdm12/golibs/files" - "github.com/qdm12/golibs/logging" "github.com/qdm12/private-internet-access-docker/internal/constants" "github.com/qdm12/private-internet-access-docker/internal/models" ) type mullvad struct { fileManager files.FileManager - logger logging.Logger } -func newMullvad(fileManager files.FileManager, logger logging.Logger) *mullvad { +func newMullvad(fileManager files.FileManager) *mullvad { return &mullvad{ fileManager: fileManager, - logger: logger.WithPrefix("Mullvad configurator: "), } } @@ -106,11 +102,3 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u func (m *mullvad) GetPortForward() (port uint16, err error) { panic("port forwarding is not supported for mullvad") } - -func (m *mullvad) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { - panic("port forwarding is not supported for mullvad") -} - -func (m *mullvad) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { - panic("port forwarding is not supported for mullvad") -} diff --git a/internal/provider/pia.go b/internal/provider/pia.go index dd9cec05..4fa55455 100644 --- a/internal/provider/pia.go +++ b/internal/provider/pia.go @@ -1,7 +1,6 @@ package provider import ( - "context" "encoding/hex" "encoding/json" "fmt" @@ -14,24 +13,21 @@ import ( "github.com/qdm12/golibs/network" "github.com/qdm12/golibs/verification" "github.com/qdm12/private-internet-access-docker/internal/constants" - "github.com/qdm12/private-internet-access-docker/internal/firewall" "github.com/qdm12/private-internet-access-docker/internal/models" ) type pia struct { client network.Client fileManager files.FileManager - firewall firewall.Configurator random random.Random verifyPort func(port string) error lookupIP func(host string) ([]net.IP, error) } -func newPrivateInternetAccess(client network.Client, fileManager files.FileManager, firewall firewall.Configurator) *pia { +func newPrivateInternetAccess(client network.Client, fileManager files.FileManager) *pia { return &pia{ client: client, fileManager: fileManager, - firewall: firewall, random: random.NewRandom(), verifyPort: verification.NewVerifier().VerifyPort, lookupIP: net.LookupIP} @@ -168,7 +164,7 @@ func (p *pia) GetPortForward() (port uint16, err error) { } clientID := hex.EncodeToString(b) url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID) - content, status, err := p.client.GetContent(url) + content, status, err := p.client.GetContent(url) // TODO add ctx switch { case err != nil: return 0, err @@ -185,15 +181,3 @@ func (p *pia) GetPortForward() (port uint16, err error) { } return body.Port, nil } - -func (p *pia) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { - return p.fileManager.WriteLinesToFile( - string(filepath), - []string{fmt.Sprintf("%d", port)}, - files.Ownership(uid, gid), - files.Permissions(0400)) -} - -func (p *pia) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { - return p.firewall.AllowInputTrafficOnPort(ctx, device, port) -} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 61380f76..05783a87 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1,13 +1,9 @@ package provider import ( - "context" - "github.com/qdm12/golibs/files" - "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" "github.com/qdm12/private-internet-access-docker/internal/constants" - "github.com/qdm12/private-internet-access-docker/internal/firewall" "github.com/qdm12/private-internet-access-docker/internal/models" ) @@ -16,16 +12,14 @@ type Provider interface { GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (err error) GetPortForward() (port uint16, err error) - WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) - AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) } -func New(provider models.VPNProvider, logger logging.Logger, client network.Client, fileManager files.FileManager, firewall firewall.Configurator) Provider { +func New(provider models.VPNProvider, client network.Client, fileManager files.FileManager) Provider { switch provider { case constants.PrivateInternetAccess: - return newPrivateInternetAccess(client, fileManager, firewall) + return newPrivateInternetAccess(client, fileManager) case constants.Mullvad: - return newMullvad(fileManager, logger) + return newMullvad(fileManager) case constants.Windscribe: return newWindscribe(fileManager) case constants.Surfshark: diff --git a/internal/provider/surfshark.go b/internal/provider/surfshark.go index 6ced928e..f3071e10 100644 --- a/internal/provider/surfshark.go +++ b/internal/provider/surfshark.go @@ -1,7 +1,6 @@ package provider import ( - "context" "fmt" "net" "strings" @@ -127,11 +126,3 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity, func (s *surfshark) GetPortForward() (port uint16, err error) { panic("port forwarding is not supported for surfshark") } - -func (s *surfshark) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { - panic("port forwarding is not supported for surfshark") -} - -func (s *surfshark) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { - panic("port forwarding is not supported for surfshark") -} diff --git a/internal/provider/windscribe.go b/internal/provider/windscribe.go index c5f99cfc..6911cfd9 100644 --- a/internal/provider/windscribe.go +++ b/internal/provider/windscribe.go @@ -1,7 +1,6 @@ package provider import ( - "context" "fmt" "net" "strings" @@ -124,11 +123,3 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity func (w *windscribe) GetPortForward() (port uint16, err error) { panic("port forwarding is not supported for windscribe") } - -func (w *windscribe) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { - panic("port forwarding is not supported for windscribe") -} - -func (w *windscribe) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { - panic("port forwarding is not supported for windscribe") -} diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go index 44161a89..fa670b05 100644 --- a/internal/routing/mutate.go +++ b/internal/routing/mutate.go @@ -7,29 +7,34 @@ import ( "fmt" ) -func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error { - for _, subnet := range subnets { - exists, err := r.routeExists(subnet) - if err != nil { - return err - } else if exists { // thanks to @npawelek https://github.com/npawelek - if err := r.removeRoute(ctx, subnet); err != nil { - return err - } - } - r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface) - output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface) - if err != nil { - return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err) - } +func (r *routing) AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error { + subnetStr := subnet.String() + r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface) + exists, err := r.routeExists(subnet) + if err != nil { + return err + } else if exists { + return nil + } + output, err := r.commander.Run(ctx, "ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface) + if err != nil { + return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway, "dev", defaultInterface, output, err) } return nil } -func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) { - output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String()) +func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) { + subnetStr := subnet.String() + r.logger.Info("deleting route for %s", subnetStr) + exists, err := r.routeExists(subnet) if err != nil { - return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err) + return err + } else if !exists { // thanks to @npawelek https://github.com/npawelek + return nil + } + output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr) + if err != nil { + return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err) } return nil } diff --git a/internal/routing/mutate_test.go b/internal/routing/mutate_test.go index 50ecb048..fa7b2d2a 100644 --- a/internal/routing/mutate_test.go +++ b/internal/routing/mutate_test.go @@ -8,12 +8,16 @@ import ( "github.com/golang/mock/gomock" "github.com/qdm12/golibs/command/mock_command" + "github.com/qdm12/golibs/files/mock_files" + "github.com/qdm12/golibs/logging/mock_logging" + "github.com/qdm12/private-internet-access-docker/internal/constants" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_removeRoute(t *testing.T) { +func Test_DeleteRouteVia(t *testing.T) { t.Parallel() + ctx := context.Background() tests := map[string]struct { subnet net.IPNet runOutput string @@ -22,26 +26,26 @@ func Test_removeRoute(t *testing.T) { }{ "no output no error": { subnet: net.IPNet{ - IP: net.IP{192, 168, 1, 0}, + IP: net.IP{192, 168, 2, 0}, Mask: net.IPMask{255, 255, 255, 0}, }, }, "error only": { subnet: net.IPNet{ - IP: net.IP{192, 168, 1, 0}, + IP: net.IP{192, 168, 2, 0}, Mask: net.IPMask{255, 255, 255, 0}, }, runErr: fmt.Errorf("error"), - err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"), + err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"), }, "error and output": { subnet: net.IPNet{ - IP: net.IP{192, 168, 1, 0}, + IP: net.IP{192, 168, 2, 0}, Mask: net.IPMask{255, 255, 255, 0}, }, runErr: fmt.Errorf("error"), runOutput: "output", - err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"), + err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"), }, } for name, tc := range tests { @@ -50,12 +54,26 @@ func Test_removeRoute(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - commander := mock_command.NewMockCommander(mockCtrl) - commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()). + subnetStr := tc.subnet.String() + + logger := mock_logging.NewMockLogger(mockCtrl) + logger.EXPECT().Info("deleting route for %s") + commander := mock_command.NewMockCommander(mockCtrl) + commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr). Return(tc.runOutput, tc.runErr).Times(1) - r := &routing{commander: commander} - err := r.removeRoute(context.Background(), tc.subnet) + fileManager := mock_files.NewMockFileManager(mockCtrl) + routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +`) + fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil) + r := &routing{ + logger: logger, + commander: commander, + fileManager: fileManager, + } + + err := r.DeleteRouteVia(ctx, tc.subnet) if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) diff --git a/internal/routing/routing.go b/internal/routing/routing.go index a363bbf0..9dba928d 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -10,7 +10,8 @@ import ( ) type Routing interface { - AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error + AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error + DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) } diff --git a/internal/shadowsocks/loop.go b/internal/shadowsocks/loop.go index 6c398162..b476dd5b 100644 --- a/internal/shadowsocks/loop.go +++ b/internal/shadowsocks/loop.go @@ -59,6 +59,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait } defer l.logger.Warn("loop exited") + var previousPort uint16 for ctx.Err() == nil { nameserver := l.dnsSettings.PlaintextAddress.String() if l.dnsSettings.Enabled { @@ -75,11 +76,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait l.logAndWait(ctx, err) continue } - err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port) - // TODO remove firewall rule on exit below - if err != nil { - l.logger.Error(err) + + if previousPort > 0 { + if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil { + l.logger.Error(err) + continue + } } + if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil { + l.logger.Error(err) + continue + } + previousPort = l.settings.Port + shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background()) stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log) if err != nil { diff --git a/internal/tinyproxy/loop.go b/internal/tinyproxy/loop.go index fbd9580f..48430121 100644 --- a/internal/tinyproxy/loop.go +++ b/internal/tinyproxy/loop.go @@ -57,6 +57,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait } defer l.logger.Warn("loop exited") + var previousPort uint16 for ctx.Err() == nil { err := l.conf.MakeConf( l.settings.LogLevel, @@ -69,11 +70,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait l.logAndWait(ctx, err) continue } - err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port) - // TODO remove firewall rule on exit below - if err != nil { - l.logger.Error(err) + + if previousPort > 0 { + if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil { + l.logger.Error(err) + continue + } } + if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil { + l.logger.Error(err) + continue + } + previousPort = l.settings.Port + tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background()) stream, waitFn, err := l.conf.Start(tinyproxyCtx) if err != nil {