Firewall refactoring

- Ability to enable and disable rules in various loops
- Simplified code overall
- Port forwarding moved into openvpn loop
- Route addition and removal improved
This commit is contained in:
Quentin McGaw
2020-07-11 21:03:55 +00:00
parent ccf11990f1
commit b1596bc7e4
20 changed files with 887 additions and 359 deletions

View File

@@ -17,13 +17,10 @@ import (
"github.com/qdm12/golibs/network"
"github.com/qdm12/private-internet-access-docker/internal/alpine"
"github.com/qdm12/private-internet-access-docker/internal/cli"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/dns"
"github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
"github.com/qdm12/private-internet-access-docker/internal/params"
"github.com/qdm12/private-internet-access-docker/internal/provider"
"github.com/qdm12/private-internet-access-docker/internal/publicip"
"github.com/qdm12/private-internet-access-docker/internal/routing"
"github.com/qdm12/private-internet-access-docker/internal/server"
@@ -72,8 +69,8 @@ func _main(background context.Context, args []string) int {
alpineConf := alpine.NewConfigurator(fileManager)
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
dnsConf := dns.NewConfigurator(logger, client, fileManager)
firewallConf := firewall.NewConfigurator(logger)
routingConf := routing.NewRouting(logger, fileManager)
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
streamMerger := command.NewStreamMerger()
@@ -93,12 +90,6 @@ func _main(background context.Context, args []string) int {
// Should never change
uid, gid := allSettings.System.UID, allSettings.System.GID
providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf)
if !allSettings.Firewall.Enabled {
firewallConf.Disable()
}
err = alpineConf.CreateUser("nonrootuser", uid)
fatalOnError(err)
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
@@ -112,17 +103,6 @@ func _main(background context.Context, args []string) int {
fatalOnError(err)
}
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
fatalOnError(err)
// Temporarily reset chain policies allowing Kubernetes sidecar to
// successfully restart the container. Without this, the existing rules will
// pre-exist, preventing the nslookup of the PIA region address. These will
// simply be redundant at Docker runtime as they will already be set this way
// Thanks to @npawelek https://github.com/npawelek
err = firewallConf.AcceptAll(ctx)
fatalOnError(err)
connectedCh := make(chan struct{})
signalConnected := func() {
connectedCh <- struct{}{}
@@ -130,44 +110,23 @@ func _main(background context.Context, args []string) int {
defer close(connectedCh)
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
connections, err := providerConf.GetOpenVPNConnections(allSettings.OpenVPN.Provider.ServerSelection)
fatalOnError(err)
err = providerConf.BuildConf(
connections,
allSettings.OpenVPN.Verbosity,
uid,
gid,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth,
allSettings.OpenVPN.Provider.ExtraConfigOptions,
)
fatalOnError(err)
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
fatalOnError(err)
err = firewallConf.Clear(ctx)
fatalOnError(err)
err = firewallConf.BlockAll(ctx)
fatalOnError(err)
err = firewallConf.CreateGeneralRules(ctx)
fatalOnError(err)
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
fatalOnError(err)
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
fatalOnError(err)
err = firewallConf.RunUserPostRules(ctx, fileManager, "/iptables/post-rules.txt")
fatalOnError(err)
// TODO replace these with methods on loopers and pass loopers around
restartOpenvpn := make(chan struct{})
portForward := make(chan struct{})
restartUnbound := make(chan struct{})
restartPublicIP := make(chan struct{})
restartTinyproxy := make(chan struct{})
restartShadowsocks := make(chan struct{})
openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError, uid, gid)
if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default
fatalOnError(err)
}
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
// wait for restartOpenvpn
go openvpnLooper.Run(ctx, restartOpenvpn, wg)
go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg)
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
// wait for restartUnbound
@@ -191,7 +150,6 @@ func _main(background context.Context, args []string) int {
}
go func() {
first := true
var restartTickerContext context.Context
var restartTickerCancel context.CancelFunc = func() {}
for {
@@ -200,14 +158,10 @@ func _main(background context.Context, args []string) int {
restartTickerCancel()
return
case <-connectedCh: // blocks until openvpn is connected
if first {
first = false
restartUnbound <- struct{}{}
}
restartTickerCancel()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
onConnected(allSettings, logger, routingConf, defaultInterface, providerConf, restartPublicIP)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
}
}
}()
@@ -224,11 +178,10 @@ func _main(background context.Context, args []string) int {
syscall.SIGTERM,
os.Interrupt,
)
exitStatus := 0
shutdownErrorsCount := 0
select {
case signal := <-signalsCh:
logger.Warn("Caught OS signal %s, shutting down", signal)
exitStatus = 1
cancel()
case <-ctx.Done():
logger.Warn("context canceled, shutting down")
@@ -236,20 +189,37 @@ func _main(background context.Context, args []string) int {
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
logger.Error(err)
exitStatus = 1
shutdownErrorsCount++
}
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
logger.Error(err)
exitStatus = 1
shutdownErrorsCount++
}
}
wg.Wait()
return exitStatus
waiting, waited := context.WithTimeout(context.Background(), time.Second)
go func() {
defer waited()
wg.Wait()
}()
<-waiting.Done()
if waiting.Err() == context.DeadlineExceeded {
if shutdownErrorsCount > 0 {
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
}
logger.Warn("Shutdown timed out")
return 1
}
if shutdownErrorsCount > 0 {
logger.Warn("Shutdown had %d errors")
return 1
}
logger.Info("Shutdown successful")
return 0
}
func makeFatalOnError(logger logging.Logger, cancel func(), wg *sync.WaitGroup) func(err error) {
func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc, wg *sync.WaitGroup) func(err error) {
return func(err error) {
if err != nil {
logger.Error(err)
@@ -321,48 +291,25 @@ func trimEventualProgramPrefix(s string) string {
}
}
func onConnected(allSettings settings.Settings,
logger logging.Logger, routingConf routing.Routing, defaultInterface string,
providerConf provider.Provider, restartPublicIP chan<- struct{},
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
portForward, restartUnbound, restartPublicIP chan<- struct{},
) {
restartUnbound <- struct{}{}
restartPublicIP <- struct{}{}
uid, gid := allSettings.System.UID, allSettings.System.GID
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
time.AfterFunc(5*time.Second, func() {
setupPortForwarding(logger, providerConf, allSettings.OpenVPN.Provider.PortForwarding.Filepath, uid, gid)
portForward <- struct{}{}
})
}
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
defaultInterface, _, _, err := routingConf.DefaultRoute()
if err != nil {
logger.Warn(err)
} else {
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
}
}
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
pfLogger := logger.WithPrefix("port forwarding: ")
var port uint16
var err error
for {
port, err = providerConf.GetPortForward()
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
if err != nil {
pfLogger.Error(err)
pfLogger.Info("retrying in 5 seconds...")
time.Sleep(5 * time.Second)
logger.Warn(err)
} else {
pfLogger.Info("port forwarded is %d", port)
break
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
}
}
pfLogger.Info("writing forwarded port to %s", filepath)
if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil {
pfLogger.Error(err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
pfLogger.Error(err)
}
}

149
internal/firewall/enable.go Normal file
View File

@@ -0,0 +1,149 @@
package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if enabled == c.enabled {
if enabled {
c.logger.Info("already enabled")
} else {
c.logger.Info("already disabled")
}
return nil
}
if !enabled {
c.logger.Info("disabling...")
if err = c.disable(ctx); err != nil {
return err
}
c.enabled = false
c.logger.Info("disabled successfully")
return nil
}
c.logger.Info("enabling...")
if err := c.enable(ctx); err != nil {
return err
}
c.enabled = true
c.logger.Info("enabled successfully")
return nil
}
func (c *configurator) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err)
}
if err = c.setAllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err)
}
// TODO routes?
return nil
}
// To use in defered call when enabling the firewall
func (c *configurator) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.SetEnabled(ctx, false); err != nil {
c.logger.Error(err)
}
}
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute()
if err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
fmt.Println(1)
if err = c.setAllPolicies(ctx, "DROP"); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
const remove = false
defer func() {
if err != nil {
c.fallbackToDisabled(ctx)
}
}()
// Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
for _, conn := range c.vpnConnections {
if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
for _, subnet := range c.allowedSubnets {
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
// Re-ensure all routes exist
for _, subnet := range c.allowedSubnets {
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
for port := range c.allowedPorts {
// TODO restrict interface
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if c.portForwarded > 0 {
const tun = string(constants.TUN)
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
return nil
}

View File

@@ -3,42 +3,49 @@ package firewall
import (
"context"
"net"
"sync"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/routing"
)
// Configurator allows to change firewall rules and modify network routes
type Configurator interface {
Version(ctx context.Context) (string, error)
AcceptAll(ctx context.Context) error
Clear(ctx context.Context) error
BlockAll(ctx context.Context) error
CreateGeneralRules(ctx context.Context) error
CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error
AllowAnyIncomingOnPort(ctx context.Context, port uint16) error
RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error
Disable()
SetEnabled(ctx context.Context, enabled bool) (err error)
SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error)
SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error)
SetAllowedPort(ctx context.Context, port uint16) error
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
SetPortForward(ctx context.Context, port uint16) (err error)
}
type configurator struct {
commander command.Commander
logger logging.Logger
disabled bool
type configurator struct { //nolint:maligned
commander command.Commander
logger logging.Logger
routing routing.Routing
fileManager files.FileManager // for custom iptables rules
iptablesMutex sync.Mutex
// State
enabled bool
vpnConnections []models.OpenVPNConnection
allowedSubnets []net.IPNet
allowedPorts map[uint16]struct{}
portForwarded uint16
stateMutex sync.Mutex
}
// NewConfigurator creates a new Configurator instance
func NewConfigurator(logger logging.Logger) Configurator {
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
return &configurator{
commander: command.NewCommander(),
logger: logger.WithPrefix("firewall configurator: "),
commander: command.NewCommander(),
logger: logger.WithPrefix("firewall: "),
routing: routing,
fileManager: fileManager,
allowedPorts: make(map[uint16]struct{}),
}
}
func (c *configurator) Disable() {
c.disabled = true
}

View File

@@ -6,10 +6,32 @@ import (
"net"
"strings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func appendOrDelete(remove bool) string {
if remove {
return "--delete"
}
return "--append"
}
// flipRule changes an append rule in a delete rule or a delete rule into an
// append rule.
func flipRule(rule string) string {
switch {
case strings.HasPrefix(rule, "-A"):
return strings.Replace(rule, "-A", "-D", 1)
case strings.HasPrefix(rule, "--append"):
return strings.Replace(rule, "--append", "-D", 1)
case strings.HasPrefix(rule, "-D"):
return strings.Replace(rule, "-D", "-A", 1)
case strings.HasPrefix(rule, "--delete"):
return strings.Replace(rule, "--delete", "-A", 1)
}
return rule
}
// Version obtains the version of the installed iptables
func (c *configurator) Version(ctx context.Context) (string, error) {
output, err := c.commander.Run(ctx, "iptables", "--version")
@@ -33,6 +55,8 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions
}
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()
flags := strings.Fields(instruction)
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err)
@@ -40,146 +64,119 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s
return nil
}
func (c *configurator) Clear(ctx context.Context) error {
if c.disabled {
return nil
}
c.logger.Info("clearing all rules")
func (c *configurator) clearAllRules(ctx context.Context) error {
return c.runIptablesInstructions(ctx, []string{
"--flush",
"--delete-chain",
"--flush", // flush all chains
"--delete-chain", // delete all chains
})
}
func (c *configurator) AcceptAll(ctx context.Context) error {
if c.disabled {
return nil
func (c *configurator) setAllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
default:
return fmt.Errorf("policy %q not recognized", policy)
}
c.logger.Info("accepting all traffic")
return c.runIptablesInstructions(ctx, []string{
"-P INPUT ACCEPT",
"-P OUTPUT ACCEPT",
"-P FORWARD ACCEPT",
fmt.Sprintf("--policy INPUT %s", policy),
fmt.Sprintf("--policy OUTPUT %s", policy),
fmt.Sprintf("--policy FORWARD %s", policy),
})
}
func (c *configurator) BlockAll(ctx context.Context) error {
if c.disabled {
return nil
}
c.logger.Info("blocking all traffic")
func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
))
}
func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
return c.runIptablesInstructions(ctx, []string{
"-P INPUT DROP",
"-F OUTPUT",
"-P OUTPUT DROP",
"-P FORWARD DROP",
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
})
}
func (c *configurator) CreateGeneralRules(ctx context.Context) error {
if c.disabled {
return nil
}
c.logger.Info("creating general rules")
return c.runIptablesInstructions(ctx, []string{
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A OUTPUT -o lo -j ACCEPT",
"-A INPUT -i lo -j ACCEPT",
})
func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error {
return c.runIptablesInstruction(ctx,
fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
}
func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
if c.disabled {
return nil
}
for _, connection := range connections {
c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d",
connection.IP, defaultInterface, connection.Protocol, connection.Port)
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil {
return err
}
}
if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
return err
}
return nil
}
func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
if c.disabled {
return nil
}
func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
subnetStr := subnet.String()
c.logger.Info("accepting input and output traffic for %s", subnetStr)
if err := c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
}); err != nil {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
))
}
// Thanks to @npawelek
func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
subnetStr := subnet.String()
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runIptablesInstruction(ctx, fmt.Sprintf(
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
))
}
// Used for port forwarding, with intf set to tun
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
return c.runIptablesInstruction(ctx,
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port),
)
}
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
exists, err := c.fileManager.FileExists(filepath)
if err != nil {
return err
}
for _, extraSubnet := range extraSubnets {
extraSubnetStr := extraSubnet.String()
c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr)
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
return err
}
// Thanks to @npawelek
c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr)
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
return err
}
}
return nil
}
// Used for port forwarding
func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error {
if c.disabled {
} else if !exists {
return nil
}
c.logger.Info("accepting input traffic through %s on port %d", device, port)
return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
})
}
func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error {
if c.disabled {
return nil
}
c.logger.Info("accepting any input traffic on port %d", port)
return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port),
fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port),
})
}
func (c *configurator) RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error {
exists, err := fileManager.FileExists(filepath)
b, err := c.fileManager.ReadFile(filepath)
if err != nil {
return err
}
if exists {
b, err := fileManager.ReadFile(filepath)
if err != nil {
return err
lines := strings.Split(string(b), "\n")
successfulRules := []string{}
defer func() {
// transaction-like rollback
if err == nil || ctx.Err() != nil {
return
}
lines := strings.Split(string(b), "\n")
var rules []string
for _, line := range lines {
if !strings.HasPrefix(line, "iptables ") {
continue
}
rules = append(rules, strings.TrimPrefix(line, "iptables "))
c.logger.Info("running user post firewall rule: %s", line)
for _, rule := range successfulRules {
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
return c.runIptablesInstructions(ctx, rules)
}()
for _, line := range lines {
if !strings.HasPrefix(line, "iptables ") {
continue
}
rule := strings.TrimPrefix(line, "iptables ")
if remove {
rule = flipRule(rule)
}
if err = c.runIptablesInstruction(ctx, rule); err != nil {
return fmt.Errorf("cannot run custom rule: %w", err)
}
successfulRules = append(successfulRules, rule)
}
return nil
}

109
internal/firewall/ports.go Normal file
View File

@@ -0,0 +1,109 @@
package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == 0 {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal list")
c.allowedPorts[port] = struct{}{}
return nil
}
c.logger.Info("setting allowed port %d through firewall...", port)
if _, ok := c.allowedPorts[port]; ok {
return nil
}
const remove = false
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
}
c.allowedPorts[port] = struct{}{}
return nil
}
func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == 0 {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal list")
delete(c.allowedPorts, port)
return nil
}
c.logger.Info("removing allowed port %d through firewall...", port)
if _, ok := c.allowedPorts[port]; !ok {
return nil
}
const remove = true
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
}
delete(c.allowedPorts, port)
return nil
}
// Use 0 to remove
func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == c.portForwarded {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating port forwarded internally")
c.portForwarded = port
return nil
}
const tun = string(constants.TUN)
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil {
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil {
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
}
if port == 0 { // not changing port
c.portForwarded = 0
return nil
}
if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil {
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil {
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
}
return nil
}

View File

@@ -0,0 +1,127 @@
package firewall
import (
"context"
"fmt"
"net"
)
func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
c.allowedSubnets = make([]net.IPNet, len(subnets))
copy(c.allowedSubnets, subnets)
return nil
}
c.logger.Info("setting allowed subnets through firewall...")
subnetsToAdd := findSubnetsToAdd(c.allowedSubnets, subnets)
subnetsToRemove := findSubnetsToRemove(c.allowedSubnets, subnets)
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
return nil
}
defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute()
if err != nil {
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
}
c.removeSubnets(ctx, subnetsToRemove, defaultInterface)
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil {
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
}
return nil
}
func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) {
for _, newSubnet := range newSubnets {
found := false
for _, oldSubnet := range oldSubnets {
if subnetsAreEqual(oldSubnet, newSubnet) {
found = true
break
}
}
if !found {
subnetsToAdd = append(subnetsToAdd, newSubnet)
}
}
return subnetsToAdd
}
func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) {
for _, oldSubnet := range oldSubnets {
found := false
for _, newSubnet := range newSubnets {
if subnetsAreEqual(oldSubnet, newSubnet) {
found = true
break
}
}
if !found {
subnetsToRemove = append(subnetsToRemove, oldSubnet)
}
}
return subnetsToRemove
}
func subnetsAreEqual(a, b net.IPNet) bool {
return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
}
func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet {
L := len(subnets)
for i := range subnets {
if subnetsAreEqual(subnet, subnets[i]) {
subnets[i] = subnets[L-1]
subnets = subnets[:L-1]
break
}
}
return subnets
}
func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) {
const remove = true
for _, subnet := range subnets {
failed := false
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
failed = true
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
}
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
failed = true
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
}
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
failed = true
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
}
if failed {
continue
}
c.allowedSubnets = removeSubnetFromSubnets(c.allowedSubnets, subnet)
}
}
func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error {
const remove = false
for _, subnet := range subnets {
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
}
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
}
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
}
c.allowedSubnets = append(c.allowedSubnets, subnet)
}
return nil
}

112
internal/firewall/vpn.go Normal file
View File

@@ -0,0 +1,112 @@
package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if !c.enabled {
c.logger.Info("firewall disabled, only updating VPN connections internal list")
c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
copy(c.vpnConnections, connections)
return nil
}
c.logger.Info("setting VPN connections through firewall...")
connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections)
connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections)
if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 {
return nil
}
defaultInterface, _, _, err := c.routing.DefaultRoute()
if err != nil {
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
}
// TODO remove elsewhere?
if err := c.acceptOutputThroughInterface(ctx, string(constants.TUN), false); err != nil {
return fmt.Errorf("cannot allow traffic through tunnel: %w", err)
}
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
}
return nil
}
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
L := len(connections)
for i := range connections {
if connection.Equal(connections[i]) {
connections[i] = connections[L-1]
connections = connections[:L-1]
break
}
}
return connections
}
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
for _, newConnection := range newConnections {
found := false
for _, oldConnection := range oldConnections {
if oldConnection.Equal(newConnection) {
found = true
break
}
}
if !found {
connectionsToAdd = append(connectionsToAdd, newConnection)
}
}
return connectionsToAdd
}
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
for _, oldConnection := range oldConnections {
found := false
for _, newConnection := range newConnections {
if oldConnection.Equal(newConnection) {
found = true
break
}
}
if !found {
connectionsToRemove = append(connectionsToRemove, oldConnection)
}
}
return connectionsToRemove
}
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
for _, conn := range connections {
const remove = true
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
continue
}
c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn)
}
}
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
const remove = false
for _, conn := range connections {
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
return err
}
c.vpnConnections = append(c.vpnConnections, conn)
}
return nil
}

View File

@@ -7,3 +7,7 @@ type OpenVPNConnection struct {
Port uint16
Protocol NetworkProtocol
}
func (o *OpenVPNConnection) Equal(other OpenVPNConnection) bool {
return o.IP.Equal(other.IP) && o.Port == other.Port && o.Protocol == other.Protocol
}

View File

@@ -2,43 +2,64 @@ package openvpn
import (
"context"
"fmt"
"sync"
"time"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/provider"
"github.com/qdm12/private-internet-access-docker/internal/settings"
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
}
type looper struct {
conf Configurator
settings settings.OpenVPN
// Variable parameters
provider models.VPNProvider
settings settings.OpenVPN
// Fixed parameters
uid int
gid int
// Configurators
conf Configurator
fw firewall.Configurator
// Other objects
logger logging.Logger
client network.Client
fileManager files.FileManager
streamMerger command.StreamMerger
fatalOnError func(err error)
uid int
gid int
}
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
streamMerger command.StreamMerger, fatalOnError func(err error), uid, gid int) Looper {
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
uid, gid int,
conf Configurator, fw firewall.Configurator,
logger logging.Logger, client network.Client, fileManager files.FileManager,
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
return &looper{
conf: conf,
provider: provider,
settings: settings,
logger: logger.WithPrefix("openvpn: "),
streamMerger: streamMerger,
fatalOnError: fatalOnError,
uid: uid,
gid: gid,
conf: conf,
fw: fw,
logger: logger.WithPrefix("openvpn: "),
client: client,
fileManager: fileManager,
streamMerger: streamMerger,
fatalOnError: fatalOnError,
}
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done()
select {
@@ -46,17 +67,51 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done():
return
}
for {
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
err := l.conf.WriteAuthFile(
l.settings.User,
l.settings.Password,
defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
providerConf := provider.New(l.provider, l.client, l.fileManager)
connections, err := providerConf.GetOpenVPNConnections(l.settings.Provider.ServerSelection)
l.fatalOnError(err)
err = providerConf.BuildConf(
connections,
l.settings.Verbosity,
l.uid,
l.gid,
l.settings.Root,
l.settings.Cipher,
l.settings.Auth,
l.settings.Provider.ExtraConfigOptions,
)
l.fatalOnError(err)
stream, waitFn, err := l.conf.Start(openvpnCtx)
err = l.conf.WriteAuthFile(l.settings.User, l.settings.Password, l.uid, l.gid)
l.fatalOnError(err)
if err := l.fw.SetVPNConnections(ctx, connections); err != nil {
l.fatalOnError(err)
}
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(openvpnCtx)
if err != nil {
openvpnCancel()
l.logAndWait(ctx, err)
continue
}
go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-portForward:
l.portForward(ctx, providerConf)
}
}
}(openvpnCtx)
go l.streamMerger.Merge(openvpnCtx, stream,
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
waitError := make(chan error)
@@ -74,13 +129,53 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-restart: // triggered restart
l.logger.Info("restarting")
openvpnCancel()
<-waitError
close(waitError)
case err := <-waitError: // unexpected error
l.logger.Warn(err)
l.logger.Info("restarting")
openvpnCancel()
close(waitError)
time.Sleep(time.Second)
l.logAndWait(ctx, err)
}
}
}
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err)
l.logger.Info("retrying in 30 seconds")
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() // just for the linter
<-ctx.Done()
}
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider) {
if !l.settings.Provider.PortForwarding.Enabled {
return
}
var port uint16
err := fmt.Errorf("")
for err != nil {
if ctx.Err() != nil {
return
}
port, err = providerConf.GetPortForward()
if err != nil {
l.logAndWait(ctx, err)
continue
}
l.logger.Info("port forwarded is %d", port)
}
filepath := l.settings.Provider.PortForwarding.Filepath
l.logger.Info("writing forwarded port to %s", filepath)
err = l.fileManager.WriteLinesToFile(
string(filepath), []string{fmt.Sprintf("%d", port)},
files.Ownership(l.uid, l.gid), files.Permissions(0400),
)
if err != nil {
l.logger.Error(err)
}
if err := l.fw.SetPortForward(ctx, port); err != nil {
l.logger.Error(err)
}
}

View File

@@ -1,7 +1,6 @@
package provider
import (
"context"
"fmt"
"net"
"strings"
@@ -125,11 +124,3 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity
func (c *cyberghost) GetPortForward() (port uint16, err error) {
panic("port forwarding is not supported for cyberghost")
}
func (c *cyberghost) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
panic("port forwarding is not supported for cyberghost")
}
func (c *cyberghost) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
panic("port forwarding is not supported for cyberghost")
}

View File

@@ -1,24 +1,20 @@
package provider
import (
"context"
"fmt"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
type mullvad struct {
fileManager files.FileManager
logger logging.Logger
}
func newMullvad(fileManager files.FileManager, logger logging.Logger) *mullvad {
func newMullvad(fileManager files.FileManager) *mullvad {
return &mullvad{
fileManager: fileManager,
logger: logger.WithPrefix("Mullvad configurator: "),
}
}
@@ -106,11 +102,3 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u
func (m *mullvad) GetPortForward() (port uint16, err error) {
panic("port forwarding is not supported for mullvad")
}
func (m *mullvad) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
panic("port forwarding is not supported for mullvad")
}
func (m *mullvad) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
panic("port forwarding is not supported for mullvad")
}

View File

@@ -1,7 +1,6 @@
package provider
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
@@ -14,24 +13,21 @@ import (
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/verification"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
type pia struct {
client network.Client
fileManager files.FileManager
firewall firewall.Configurator
random random.Random
verifyPort func(port string) error
lookupIP func(host string) ([]net.IP, error)
}
func newPrivateInternetAccess(client network.Client, fileManager files.FileManager, firewall firewall.Configurator) *pia {
func newPrivateInternetAccess(client network.Client, fileManager files.FileManager) *pia {
return &pia{
client: client,
fileManager: fileManager,
firewall: firewall,
random: random.NewRandom(),
verifyPort: verification.NewVerifier().VerifyPort,
lookupIP: net.LookupIP}
@@ -168,7 +164,7 @@ func (p *pia) GetPortForward() (port uint16, err error) {
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
content, status, err := p.client.GetContent(url)
content, status, err := p.client.GetContent(url) // TODO add ctx
switch {
case err != nil:
return 0, err
@@ -185,15 +181,3 @@ func (p *pia) GetPortForward() (port uint16, err error) {
}
return body.Port, nil
}
func (p *pia) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
return p.fileManager.WriteLinesToFile(
string(filepath),
[]string{fmt.Sprintf("%d", port)},
files.Ownership(uid, gid),
files.Permissions(0400))
}
func (p *pia) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
return p.firewall.AllowInputTrafficOnPort(ctx, device, port)
}

View File

@@ -1,13 +1,9 @@
package provider
import (
"context"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
@@ -16,16 +12,14 @@ type Provider interface {
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (err error)
GetPortForward() (port uint16, err error)
WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error)
AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error)
}
func New(provider models.VPNProvider, logger logging.Logger, client network.Client, fileManager files.FileManager, firewall firewall.Configurator) Provider {
func New(provider models.VPNProvider, client network.Client, fileManager files.FileManager) Provider {
switch provider {
case constants.PrivateInternetAccess:
return newPrivateInternetAccess(client, fileManager, firewall)
return newPrivateInternetAccess(client, fileManager)
case constants.Mullvad:
return newMullvad(fileManager, logger)
return newMullvad(fileManager)
case constants.Windscribe:
return newWindscribe(fileManager)
case constants.Surfshark:

View File

@@ -1,7 +1,6 @@
package provider
import (
"context"
"fmt"
"net"
"strings"
@@ -127,11 +126,3 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity,
func (s *surfshark) GetPortForward() (port uint16, err error) {
panic("port forwarding is not supported for surfshark")
}
func (s *surfshark) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
panic("port forwarding is not supported for surfshark")
}
func (s *surfshark) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
panic("port forwarding is not supported for surfshark")
}

View File

@@ -1,7 +1,6 @@
package provider
import (
"context"
"fmt"
"net"
"strings"
@@ -124,11 +123,3 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity
func (w *windscribe) GetPortForward() (port uint16, err error) {
panic("port forwarding is not supported for windscribe")
}
func (w *windscribe) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
panic("port forwarding is not supported for windscribe")
}
func (w *windscribe) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
panic("port forwarding is not supported for windscribe")
}

View File

@@ -7,29 +7,34 @@ import (
"fmt"
)
func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
for _, subnet := range subnets {
exists, err := r.routeExists(subnet)
if err != nil {
return err
} else if exists { // thanks to @npawelek https://github.com/npawelek
if err := r.removeRoute(ctx, subnet); err != nil {
return err
}
}
r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface)
output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err)
}
func (r *routing) AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error {
subnetStr := subnet.String()
r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface)
exists, err := r.routeExists(subnet)
if err != nil {
return err
} else if exists {
return nil
}
output, err := r.commander.Run(ctx, "ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway, "dev", defaultInterface, output, err)
}
return nil
}
func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) {
output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String())
func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) {
subnetStr := subnet.String()
r.logger.Info("deleting route for %s", subnetStr)
exists, err := r.routeExists(subnet)
if err != nil {
return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err)
return err
} else if !exists { // thanks to @npawelek https://github.com/npawelek
return nil
}
output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr)
if err != nil {
return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err)
}
return nil
}

View File

@@ -8,12 +8,16 @@ import (
"github.com/golang/mock/gomock"
"github.com/qdm12/golibs/command/mock_command"
"github.com/qdm12/golibs/files/mock_files"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_removeRoute(t *testing.T) {
func Test_DeleteRouteVia(t *testing.T) {
t.Parallel()
ctx := context.Background()
tests := map[string]struct {
subnet net.IPNet
runOutput string
@@ -22,26 +26,26 @@ func Test_removeRoute(t *testing.T) {
}{
"no output no error": {
subnet: net.IPNet{
IP: net.IP{192, 168, 1, 0},
IP: net.IP{192, 168, 2, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
},
"error only": {
subnet: net.IPNet{
IP: net.IP{192, 168, 1, 0},
IP: net.IP{192, 168, 2, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
runErr: fmt.Errorf("error"),
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"),
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"),
},
"error and output": {
subnet: net.IPNet{
IP: net.IP{192, 168, 1, 0},
IP: net.IP{192, 168, 2, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
runErr: fmt.Errorf("error"),
runOutput: "output",
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"),
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"),
},
}
for name, tc := range tests {
@@ -50,12 +54,26 @@ func Test_removeRoute(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()).
subnetStr := tc.subnet.String()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("deleting route for %s")
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr).
Return(tc.runOutput, tc.runErr).Times(1)
r := &routing{commander: commander}
err := r.removeRoute(context.Background(), tc.subnet)
fileManager := mock_files.NewMockFileManager(mockCtrl)
routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
`)
fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil)
r := &routing{
logger: logger,
commander: commander,
fileManager: fileManager,
}
err := r.DeleteRouteVia(ctx, tc.subnet)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -10,7 +10,8 @@ import (
)
type Routing interface {
AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
}

View File

@@ -59,6 +59,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
}
defer l.logger.Warn("loop exited")
var previousPort uint16
for ctx.Err() == nil {
nameserver := l.dnsSettings.PlaintextAddress.String()
if l.dnsSettings.Enabled {
@@ -75,11 +76,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
l.logAndWait(ctx, err)
continue
}
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
// TODO remove firewall rule on exit below
if err != nil {
l.logger.Error(err)
if previousPort > 0 {
if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil {
l.logger.Error(err)
continue
}
}
if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil {
l.logger.Error(err)
continue
}
previousPort = l.settings.Port
shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background())
stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
if err != nil {

View File

@@ -57,6 +57,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
}
defer l.logger.Warn("loop exited")
var previousPort uint16
for ctx.Err() == nil {
err := l.conf.MakeConf(
l.settings.LogLevel,
@@ -69,11 +70,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
l.logAndWait(ctx, err)
continue
}
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
// TODO remove firewall rule on exit below
if err != nil {
l.logger.Error(err)
if previousPort > 0 {
if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil {
l.logger.Error(err)
continue
}
}
if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil {
l.logger.Error(err)
continue
}
previousPort = l.settings.Port
tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(tinyproxyCtx)
if err != nil {