Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
@@ -17,13 +17,10 @@ import (
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/cli"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/publicip"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/server"
|
||||
@@ -72,8 +69,8 @@ func _main(background context.Context, args []string) int {
|
||||
alpineConf := alpine.NewConfigurator(fileManager)
|
||||
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||
firewallConf := firewall.NewConfigurator(logger)
|
||||
routingConf := routing.NewRouting(logger, fileManager)
|
||||
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
||||
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
||||
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
||||
streamMerger := command.NewStreamMerger()
|
||||
@@ -93,12 +90,6 @@ func _main(background context.Context, args []string) int {
|
||||
// Should never change
|
||||
uid, gid := allSettings.System.UID, allSettings.System.GID
|
||||
|
||||
providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf)
|
||||
|
||||
if !allSettings.Firewall.Enabled {
|
||||
firewallConf.Disable()
|
||||
}
|
||||
|
||||
err = alpineConf.CreateUser("nonrootuser", uid)
|
||||
fatalOnError(err)
|
||||
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
|
||||
@@ -112,17 +103,6 @@ func _main(background context.Context, args []string) int {
|
||||
fatalOnError(err)
|
||||
}
|
||||
|
||||
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
|
||||
fatalOnError(err)
|
||||
|
||||
// Temporarily reset chain policies allowing Kubernetes sidecar to
|
||||
// successfully restart the container. Without this, the existing rules will
|
||||
// pre-exist, preventing the nslookup of the PIA region address. These will
|
||||
// simply be redundant at Docker runtime as they will already be set this way
|
||||
// Thanks to @npawelek https://github.com/npawelek
|
||||
err = firewallConf.AcceptAll(ctx)
|
||||
fatalOnError(err)
|
||||
|
||||
connectedCh := make(chan struct{})
|
||||
signalConnected := func() {
|
||||
connectedCh <- struct{}{}
|
||||
@@ -130,44 +110,23 @@ func _main(background context.Context, args []string) int {
|
||||
defer close(connectedCh)
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
||||
|
||||
connections, err := providerConf.GetOpenVPNConnections(allSettings.OpenVPN.Provider.ServerSelection)
|
||||
fatalOnError(err)
|
||||
err = providerConf.BuildConf(
|
||||
connections,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
uid,
|
||||
gid,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth,
|
||||
allSettings.OpenVPN.Provider.ExtraConfigOptions,
|
||||
)
|
||||
fatalOnError(err)
|
||||
|
||||
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.Clear(ctx)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.BlockAll(ctx)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.CreateGeneralRules(ctx)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.RunUserPostRules(ctx, fileManager, "/iptables/post-rules.txt")
|
||||
fatalOnError(err)
|
||||
|
||||
// TODO replace these with methods on loopers and pass loopers around
|
||||
restartOpenvpn := make(chan struct{})
|
||||
portForward := make(chan struct{})
|
||||
restartUnbound := make(chan struct{})
|
||||
restartPublicIP := make(chan struct{})
|
||||
restartTinyproxy := make(chan struct{})
|
||||
restartShadowsocks := make(chan struct{})
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError, uid, gid)
|
||||
if allSettings.Firewall.Enabled {
|
||||
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
||||
fatalOnError(err)
|
||||
}
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
|
||||
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
|
||||
// wait for restartOpenvpn
|
||||
go openvpnLooper.Run(ctx, restartOpenvpn, wg)
|
||||
go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg)
|
||||
|
||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||
// wait for restartUnbound
|
||||
@@ -191,7 +150,6 @@ func _main(background context.Context, args []string) int {
|
||||
}
|
||||
|
||||
go func() {
|
||||
first := true
|
||||
var restartTickerContext context.Context
|
||||
var restartTickerCancel context.CancelFunc = func() {}
|
||||
for {
|
||||
@@ -200,14 +158,10 @@ func _main(background context.Context, args []string) int {
|
||||
restartTickerCancel()
|
||||
return
|
||||
case <-connectedCh: // blocks until openvpn is connected
|
||||
if first {
|
||||
first = false
|
||||
restartUnbound <- struct{}{}
|
||||
}
|
||||
restartTickerCancel()
|
||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
||||
onConnected(allSettings, logger, routingConf, defaultInterface, providerConf, restartPublicIP)
|
||||
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -224,11 +178,10 @@ func _main(background context.Context, args []string) int {
|
||||
syscall.SIGTERM,
|
||||
os.Interrupt,
|
||||
)
|
||||
exitStatus := 0
|
||||
shutdownErrorsCount := 0
|
||||
select {
|
||||
case signal := <-signalsCh:
|
||||
logger.Warn("Caught OS signal %s, shutting down", signal)
|
||||
exitStatus = 1
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
logger.Warn("context canceled, shutting down")
|
||||
@@ -236,20 +189,37 @@ func _main(background context.Context, args []string) int {
|
||||
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
|
||||
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
shutdownErrorsCount++
|
||||
}
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
shutdownErrorsCount++
|
||||
}
|
||||
}
|
||||
waiting, waited := context.WithTimeout(context.Background(), time.Second)
|
||||
go func() {
|
||||
defer waited()
|
||||
wg.Wait()
|
||||
return exitStatus
|
||||
}()
|
||||
<-waiting.Done()
|
||||
if waiting.Err() == context.DeadlineExceeded {
|
||||
if shutdownErrorsCount > 0 {
|
||||
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
|
||||
}
|
||||
logger.Warn("Shutdown timed out")
|
||||
return 1
|
||||
}
|
||||
if shutdownErrorsCount > 0 {
|
||||
logger.Warn("Shutdown had %d errors")
|
||||
return 1
|
||||
}
|
||||
logger.Info("Shutdown successful")
|
||||
return 0
|
||||
}
|
||||
|
||||
func makeFatalOnError(logger logging.Logger, cancel func(), wg *sync.WaitGroup) func(err error) {
|
||||
func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc, wg *sync.WaitGroup) func(err error) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
@@ -321,48 +291,25 @@ func trimEventualProgramPrefix(s string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func onConnected(allSettings settings.Settings,
|
||||
logger logging.Logger, routingConf routing.Routing, defaultInterface string,
|
||||
providerConf provider.Provider, restartPublicIP chan<- struct{},
|
||||
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
|
||||
portForward, restartUnbound, restartPublicIP chan<- struct{},
|
||||
) {
|
||||
restartUnbound <- struct{}{}
|
||||
restartPublicIP <- struct{}{}
|
||||
uid, gid := allSettings.System.UID, allSettings.System.GID
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
setupPortForwarding(logger, providerConf, allSettings.OpenVPN.Provider.PortForwarding.Filepath, uid, gid)
|
||||
portForward <- struct{}{}
|
||||
})
|
||||
}
|
||||
|
||||
defaultInterface, _, _, err := routingConf.DefaultRoute()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||
}
|
||||
}
|
||||
|
||||
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
||||
var port uint16
|
||||
var err error
|
||||
for {
|
||||
port, err = providerConf.GetPortForward()
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
pfLogger.Info("retrying in 5 seconds...")
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
pfLogger.Info("port forwarded is %d", port)
|
||||
break
|
||||
}
|
||||
}
|
||||
pfLogger.Info("writing forwarded port to %s", filepath)
|
||||
if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
149
internal/firewall/enable.go
Normal file
149
internal/firewall/enable.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if enabled == c.enabled {
|
||||
if enabled {
|
||||
c.logger.Info("already enabled")
|
||||
} else {
|
||||
c.logger.Info("already disabled")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
c.logger.Info("disabling...")
|
||||
if err = c.disable(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
c.enabled = false
|
||||
c.logger.Info("disabled successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("enabling...")
|
||||
|
||||
if err := c.enable(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
c.enabled = true
|
||||
c.logger.Info("enabled successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) disable(ctx context.Context) (err error) {
|
||||
if err = c.clearAllRules(ctx); err != nil {
|
||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||
}
|
||||
if err = c.setAllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||
}
|
||||
// TODO routes?
|
||||
return nil
|
||||
}
|
||||
|
||||
// To use in defered call when enabling the firewall
|
||||
func (c *configurator) fallbackToDisabled(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := c.SetEnabled(ctx, false); err != nil {
|
||||
c.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
|
||||
defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println(1)
|
||||
if err = c.setAllPolicies(ctx, "DROP"); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
|
||||
const remove = false
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.fallbackToDisabled(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
|
||||
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
for _, conn := range c.vpnConnections {
|
||||
if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
for _, subnet := range c.allowedSubnets {
|
||||
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
// Re-ensure all routes exist
|
||||
for _, subnet := range c.allowedSubnets {
|
||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for port := range c.allowedPorts {
|
||||
// TODO restrict interface
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.portForwarded > 0 {
|
||||
const tun = string(constants.TUN)
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -3,42 +3,49 @@ package firewall
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
||||
)
|
||||
|
||||
// Configurator allows to change firewall rules and modify network routes
|
||||
type Configurator interface {
|
||||
Version(ctx context.Context) (string, error)
|
||||
AcceptAll(ctx context.Context) error
|
||||
Clear(ctx context.Context) error
|
||||
BlockAll(ctx context.Context) error
|
||||
CreateGeneralRules(ctx context.Context) error
|
||||
CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
|
||||
CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
|
||||
AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error
|
||||
AllowAnyIncomingOnPort(ctx context.Context, port uint16) error
|
||||
RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error
|
||||
Disable()
|
||||
SetEnabled(ctx context.Context, enabled bool) (err error)
|
||||
SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error)
|
||||
SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error)
|
||||
SetAllowedPort(ctx context.Context, port uint16) error
|
||||
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
SetPortForward(ctx context.Context, port uint16) (err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
type configurator struct { //nolint:maligned
|
||||
commander command.Commander
|
||||
logger logging.Logger
|
||||
disabled bool
|
||||
routing routing.Routing
|
||||
fileManager files.FileManager // for custom iptables rules
|
||||
iptablesMutex sync.Mutex
|
||||
|
||||
// State
|
||||
enabled bool
|
||||
vpnConnections []models.OpenVPNConnection
|
||||
allowedSubnets []net.IPNet
|
||||
allowedPorts map[uint16]struct{}
|
||||
portForwarded uint16
|
||||
stateMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance
|
||||
func NewConfigurator(logger logging.Logger) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
|
||||
return &configurator{
|
||||
commander: command.NewCommander(),
|
||||
logger: logger.WithPrefix("firewall configurator: "),
|
||||
logger: logger.WithPrefix("firewall: "),
|
||||
routing: routing,
|
||||
fileManager: fileManager,
|
||||
allowedPorts: make(map[uint16]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *configurator) Disable() {
|
||||
c.disabled = true
|
||||
}
|
||||
|
||||
@@ -6,10 +6,32 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func appendOrDelete(remove bool) string {
|
||||
if remove {
|
||||
return "--delete"
|
||||
}
|
||||
return "--append"
|
||||
}
|
||||
|
||||
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||
// append rule.
|
||||
func flipRule(rule string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(rule, "-A"):
|
||||
return strings.Replace(rule, "-A", "-D", 1)
|
||||
case strings.HasPrefix(rule, "--append"):
|
||||
return strings.Replace(rule, "--append", "-D", 1)
|
||||
case strings.HasPrefix(rule, "-D"):
|
||||
return strings.Replace(rule, "-D", "-A", 1)
|
||||
case strings.HasPrefix(rule, "--delete"):
|
||||
return strings.Replace(rule, "--delete", "-A", 1)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed iptables
|
||||
func (c *configurator) Version(ctx context.Context) (string, error) {
|
||||
output, err := c.commander.Run(ctx, "iptables", "--version")
|
||||
@@ -33,6 +55,8 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions
|
||||
}
|
||||
|
||||
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
|
||||
c.iptablesMutex.Lock() // only one iptables command at once
|
||||
defer c.iptablesMutex.Unlock()
|
||||
flags := strings.Fields(instruction)
|
||||
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
|
||||
return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err)
|
||||
@@ -40,146 +64,119 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) Clear(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("clearing all rules")
|
||||
func (c *configurator) clearAllRules(ctx context.Context) error {
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"--flush",
|
||||
"--delete-chain",
|
||||
"--flush", // flush all chains
|
||||
"--delete-chain", // delete all chains
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) AcceptAll(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
func (c *configurator) setAllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("policy %q not recognized", policy)
|
||||
}
|
||||
c.logger.Info("accepting all traffic")
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"-P INPUT ACCEPT",
|
||||
"-P OUTPUT ACCEPT",
|
||||
"-P FORWARD ACCEPT",
|
||||
fmt.Sprintf("--policy INPUT %s", policy),
|
||||
fmt.Sprintf("--policy OUTPUT %s", policy),
|
||||
fmt.Sprintf("--policy FORWARD %s", policy),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) BlockAll(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("blocking all traffic")
|
||||
func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"-P INPUT DROP",
|
||||
"-F OUTPUT",
|
||||
"-P OUTPUT DROP",
|
||||
"-P FORWARD DROP",
|
||||
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) CreateGeneralRules(ctx context.Context) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("creating general rules")
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A OUTPUT -o lo -j ACCEPT",
|
||||
"-A INPUT -i lo -j ACCEPT",
|
||||
})
|
||||
func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error {
|
||||
return c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
|
||||
}
|
||||
|
||||
func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
for _, connection := range connections {
|
||||
c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d",
|
||||
connection.IP, defaultInterface, connection.Protocol, connection.Port)
|
||||
if err := c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
}
|
||||
func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
||||
subnetStr := subnet.String()
|
||||
c.logger.Info("accepting input and output traffic for %s", subnetStr)
|
||||
if err := c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
}); err != nil {
|
||||
return err
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
for _, extraSubnet := range extraSubnets {
|
||||
extraSubnetStr := extraSubnet.String()
|
||||
c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr)
|
||||
if err := c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Thanks to @npawelek
|
||||
c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr)
|
||||
if err := c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
||||
))
|
||||
}
|
||||
|
||||
// Used for port forwarding
|
||||
func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
// Thanks to @npawelek
|
||||
func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
||||
subnetStr := subnet.String()
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
c.logger.Info("accepting input traffic through %s on port %d", device, port)
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
|
||||
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
|
||||
})
|
||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error {
|
||||
if c.disabled {
|
||||
return nil
|
||||
// Used for port forwarding, with intf set to tun
|
||||
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
c.logger.Info("accepting any input traffic on port %d", port)
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port),
|
||||
fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port),
|
||||
})
|
||||
return c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *configurator) RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error {
|
||||
exists, err := fileManager.FileExists(filepath)
|
||||
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
exists, err := c.fileManager.FileExists(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !exists {
|
||||
return nil
|
||||
}
|
||||
if exists {
|
||||
b, err := fileManager.ReadFile(filepath)
|
||||
b, err := c.fileManager.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
var rules []string
|
||||
successfulRules := []string{}
|
||||
defer func() {
|
||||
// transaction-like rollback
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range successfulRules {
|
||||
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||
}
|
||||
}()
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "iptables ") {
|
||||
continue
|
||||
}
|
||||
rules = append(rules, strings.TrimPrefix(line, "iptables "))
|
||||
c.logger.Info("running user post firewall rule: %s", line)
|
||||
rule := strings.TrimPrefix(line, "iptables ")
|
||||
if remove {
|
||||
rule = flipRule(rule)
|
||||
}
|
||||
return c.runIptablesInstructions(ctx, rules)
|
||||
if err = c.runIptablesInstruction(ctx, rule); err != nil {
|
||||
return fmt.Errorf("cannot run custom rule: %w", err)
|
||||
}
|
||||
successfulRules = append(successfulRules, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
109
internal/firewall/ports.go
Normal file
109
internal/firewall/ports.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if port == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating allowed ports internal list")
|
||||
c.allowedPorts[port] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting allowed port %d through firewall...", port)
|
||||
|
||||
if _, ok := c.allowedPorts[port]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
|
||||
}
|
||||
c.allowedPorts[port] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if port == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating allowed ports internal list")
|
||||
delete(c.allowedPorts, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("removing allowed port %d through firewall...", port)
|
||||
|
||||
if _, ok := c.allowedPorts[port]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
const remove = true
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
|
||||
}
|
||||
delete(c.allowedPorts, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use 0 to remove
|
||||
func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if port == c.portForwarded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating port forwarded internally")
|
||||
c.portForwarded = port
|
||||
return nil
|
||||
}
|
||||
|
||||
const tun = string(constants.TUN)
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil {
|
||||
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil {
|
||||
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
|
||||
}
|
||||
|
||||
if port == 0 { // not changing port
|
||||
c.portForwarded = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil {
|
||||
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil {
|
||||
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
127
internal/firewall/subnets.go
Normal file
127
internal/firewall/subnets.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
|
||||
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
||||
copy(c.allowedSubnets, subnets)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting allowed subnets through firewall...")
|
||||
|
||||
subnetsToAdd := findSubnetsToAdd(c.allowedSubnets, subnets)
|
||||
subnetsToRemove := findSubnetsToRemove(c.allowedSubnets, subnets)
|
||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||
}
|
||||
|
||||
c.removeSubnets(ctx, subnetsToRemove, defaultInterface)
|
||||
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil {
|
||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) {
|
||||
for _, newSubnet := range newSubnets {
|
||||
found := false
|
||||
for _, oldSubnet := range oldSubnets {
|
||||
if subnetsAreEqual(oldSubnet, newSubnet) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
subnetsToAdd = append(subnetsToAdd, newSubnet)
|
||||
}
|
||||
}
|
||||
return subnetsToAdd
|
||||
}
|
||||
|
||||
func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) {
|
||||
for _, oldSubnet := range oldSubnets {
|
||||
found := false
|
||||
for _, newSubnet := range newSubnets {
|
||||
if subnetsAreEqual(oldSubnet, newSubnet) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
subnetsToRemove = append(subnetsToRemove, oldSubnet)
|
||||
}
|
||||
}
|
||||
return subnetsToRemove
|
||||
}
|
||||
|
||||
func subnetsAreEqual(a, b net.IPNet) bool {
|
||||
return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
|
||||
}
|
||||
|
||||
func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet {
|
||||
L := len(subnets)
|
||||
for i := range subnets {
|
||||
if subnetsAreEqual(subnet, subnets[i]) {
|
||||
subnets[i] = subnets[L-1]
|
||||
subnets = subnets[:L-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return subnets
|
||||
}
|
||||
|
||||
func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) {
|
||||
const remove = true
|
||||
for _, subnet := range subnets {
|
||||
failed := false
|
||||
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||
failed = true
|
||||
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||
}
|
||||
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||
failed = true
|
||||
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||
}
|
||||
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
||||
failed = true
|
||||
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
|
||||
}
|
||||
if failed {
|
||||
continue
|
||||
}
|
||||
c.allowedSubnets = removeSubnetFromSubnets(c.allowedSubnets, subnet)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error {
|
||||
const remove = false
|
||||
for _, subnet := range subnets {
|
||||
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||
}
|
||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
||||
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
|
||||
}
|
||||
c.allowedSubnets = append(c.allowedSubnets, subnet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
112
internal/firewall/vpn.go
Normal file
112
internal/firewall/vpn.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating VPN connections internal list")
|
||||
c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
|
||||
copy(c.vpnConnections, connections)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting VPN connections through firewall...")
|
||||
|
||||
connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections)
|
||||
connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections)
|
||||
if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
defaultInterface, _, _, err := c.routing.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||
}
|
||||
|
||||
// TODO remove elsewhere?
|
||||
if err := c.acceptOutputThroughInterface(ctx, string(constants.TUN), false); err != nil {
|
||||
return fmt.Errorf("cannot allow traffic through tunnel: %w", err)
|
||||
}
|
||||
|
||||
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
|
||||
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
|
||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
|
||||
L := len(connections)
|
||||
for i := range connections {
|
||||
if connection.Equal(connections[i]) {
|
||||
connections[i] = connections[L-1]
|
||||
connections = connections[:L-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return connections
|
||||
}
|
||||
|
||||
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
|
||||
for _, newConnection := range newConnections {
|
||||
found := false
|
||||
for _, oldConnection := range oldConnections {
|
||||
if oldConnection.Equal(newConnection) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
connectionsToAdd = append(connectionsToAdd, newConnection)
|
||||
}
|
||||
}
|
||||
return connectionsToAdd
|
||||
}
|
||||
|
||||
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
|
||||
for _, oldConnection := range oldConnections {
|
||||
found := false
|
||||
for _, newConnection := range newConnections {
|
||||
if oldConnection.Equal(newConnection) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
connectionsToRemove = append(connectionsToRemove, oldConnection)
|
||||
}
|
||||
}
|
||||
return connectionsToRemove
|
||||
}
|
||||
|
||||
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
|
||||
for _, conn := range connections {
|
||||
const remove = true
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
|
||||
continue
|
||||
}
|
||||
c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
|
||||
const remove = false
|
||||
for _, conn := range connections {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
c.vpnConnections = append(c.vpnConnections, conn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -7,3 +7,7 @@ type OpenVPNConnection struct {
|
||||
Port uint16
|
||||
Protocol NetworkProtocol
|
||||
}
|
||||
|
||||
func (o *OpenVPNConnection) Equal(other OpenVPNConnection) bool {
|
||||
return o.IP.Equal(other.IP) && o.Port == other.Port && o.Protocol == other.Protocol
|
||||
}
|
||||
|
||||
@@ -2,43 +2,64 @@ package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
|
||||
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
conf Configurator
|
||||
// Variable parameters
|
||||
provider models.VPNProvider
|
||||
settings settings.OpenVPN
|
||||
logger logging.Logger
|
||||
streamMerger command.StreamMerger
|
||||
fatalOnError func(err error)
|
||||
// Fixed parameters
|
||||
uid int
|
||||
gid int
|
||||
// Configurators
|
||||
conf Configurator
|
||||
fw firewall.Configurator
|
||||
// Other objects
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
streamMerger command.StreamMerger
|
||||
fatalOnError func(err error)
|
||||
}
|
||||
|
||||
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
|
||||
streamMerger command.StreamMerger, fatalOnError func(err error), uid, gid int) Looper {
|
||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
uid, gid int,
|
||||
conf Configurator, fw firewall.Configurator,
|
||||
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
||||
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
provider: provider,
|
||||
settings: settings,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
streamMerger: streamMerger,
|
||||
fatalOnError: fatalOnError,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
streamMerger: streamMerger,
|
||||
fatalOnError: fatalOnError,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
|
||||
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
select {
|
||||
@@ -46,17 +67,51 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
for {
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
||||
err := l.conf.WriteAuthFile(
|
||||
l.settings.User,
|
||||
l.settings.Password,
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
for ctx.Err() == nil {
|
||||
providerConf := provider.New(l.provider, l.client, l.fileManager)
|
||||
connections, err := providerConf.GetOpenVPNConnections(l.settings.Provider.ServerSelection)
|
||||
l.fatalOnError(err)
|
||||
err = providerConf.BuildConf(
|
||||
connections,
|
||||
l.settings.Verbosity,
|
||||
l.uid,
|
||||
l.gid,
|
||||
l.settings.Root,
|
||||
l.settings.Cipher,
|
||||
l.settings.Auth,
|
||||
l.settings.Provider.ExtraConfigOptions,
|
||||
)
|
||||
l.fatalOnError(err)
|
||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||
|
||||
err = l.conf.WriteAuthFile(l.settings.User, l.settings.Password, l.uid, l.gid)
|
||||
l.fatalOnError(err)
|
||||
|
||||
if err := l.fw.SetVPNConnections(ctx, connections); err != nil {
|
||||
l.fatalOnError(err)
|
||||
}
|
||||
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
|
||||
|
||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||
if err != nil {
|
||||
openvpnCancel()
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-portForward:
|
||||
l.portForward(ctx, providerConf)
|
||||
}
|
||||
}
|
||||
}(openvpnCtx)
|
||||
|
||||
go l.streamMerger.Merge(openvpnCtx, stream,
|
||||
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||
waitError := make(chan error)
|
||||
@@ -74,13 +129,53 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
case <-restart: // triggered restart
|
||||
l.logger.Info("restarting")
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
case err := <-waitError: // unexpected error
|
||||
l.logger.Warn(err)
|
||||
l.logger.Info("restarting")
|
||||
openvpnCancel()
|
||||
close(waitError)
|
||||
time.Sleep(time.Second)
|
||||
l.logAndWait(ctx, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
l.logger.Error(err)
|
||||
l.logger.Info("retrying in 30 seconds")
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel() // just for the linter
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider) {
|
||||
if !l.settings.Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
var port uint16
|
||||
err := fmt.Errorf("")
|
||||
for err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
port, err = providerConf.GetPortForward()
|
||||
if err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
l.logger.Info("port forwarded is %d", port)
|
||||
}
|
||||
|
||||
filepath := l.settings.Provider.PortForwarding.Filepath
|
||||
l.logger.Info("writing forwarded port to %s", filepath)
|
||||
err = l.fileManager.WriteLinesToFile(
|
||||
string(filepath), []string{fmt.Sprintf("%d", port)},
|
||||
files.Ownership(l.uid, l.gid), files.Permissions(0400),
|
||||
)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
|
||||
if err := l.fw.SetPortForward(ctx, port); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -125,11 +124,3 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity
|
||||
func (c *cyberghost) GetPortForward() (port uint16, err error) {
|
||||
panic("port forwarding is not supported for cyberghost")
|
||||
}
|
||||
|
||||
func (c *cyberghost) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
||||
panic("port forwarding is not supported for cyberghost")
|
||||
}
|
||||
|
||||
func (c *cyberghost) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
||||
panic("port forwarding is not supported for cyberghost")
|
||||
}
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
type mullvad struct {
|
||||
fileManager files.FileManager
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func newMullvad(fileManager files.FileManager, logger logging.Logger) *mullvad {
|
||||
func newMullvad(fileManager files.FileManager) *mullvad {
|
||||
return &mullvad{
|
||||
fileManager: fileManager,
|
||||
logger: logger.WithPrefix("Mullvad configurator: "),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,11 +102,3 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u
|
||||
func (m *mullvad) GetPortForward() (port uint16, err error) {
|
||||
panic("port forwarding is not supported for mullvad")
|
||||
}
|
||||
|
||||
func (m *mullvad) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
||||
panic("port forwarding is not supported for mullvad")
|
||||
}
|
||||
|
||||
func (m *mullvad) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
||||
panic("port forwarding is not supported for mullvad")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -14,24 +13,21 @@ import (
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/verification"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
type pia struct {
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
firewall firewall.Configurator
|
||||
random random.Random
|
||||
verifyPort func(port string) error
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
func newPrivateInternetAccess(client network.Client, fileManager files.FileManager, firewall firewall.Configurator) *pia {
|
||||
func newPrivateInternetAccess(client network.Client, fileManager files.FileManager) *pia {
|
||||
return &pia{
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
firewall: firewall,
|
||||
random: random.NewRandom(),
|
||||
verifyPort: verification.NewVerifier().VerifyPort,
|
||||
lookupIP: net.LookupIP}
|
||||
@@ -168,7 +164,7 @@ func (p *pia) GetPortForward() (port uint16, err error) {
|
||||
}
|
||||
clientID := hex.EncodeToString(b)
|
||||
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
||||
content, status, err := p.client.GetContent(url)
|
||||
content, status, err := p.client.GetContent(url) // TODO add ctx
|
||||
switch {
|
||||
case err != nil:
|
||||
return 0, err
|
||||
@@ -185,15 +181,3 @@ func (p *pia) GetPortForward() (port uint16, err error) {
|
||||
}
|
||||
return body.Port, nil
|
||||
}
|
||||
|
||||
func (p *pia) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
||||
return p.fileManager.WriteLinesToFile(
|
||||
string(filepath),
|
||||
[]string{fmt.Sprintf("%d", port)},
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(0400))
|
||||
}
|
||||
|
||||
func (p *pia) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
||||
return p.firewall.AllowInputTrafficOnPort(ctx, device, port)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
@@ -16,16 +12,14 @@ type Provider interface {
|
||||
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
|
||||
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (err error)
|
||||
GetPortForward() (port uint16, err error)
|
||||
WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error)
|
||||
AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error)
|
||||
}
|
||||
|
||||
func New(provider models.VPNProvider, logger logging.Logger, client network.Client, fileManager files.FileManager, firewall firewall.Configurator) Provider {
|
||||
func New(provider models.VPNProvider, client network.Client, fileManager files.FileManager) Provider {
|
||||
switch provider {
|
||||
case constants.PrivateInternetAccess:
|
||||
return newPrivateInternetAccess(client, fileManager, firewall)
|
||||
return newPrivateInternetAccess(client, fileManager)
|
||||
case constants.Mullvad:
|
||||
return newMullvad(fileManager, logger)
|
||||
return newMullvad(fileManager)
|
||||
case constants.Windscribe:
|
||||
return newWindscribe(fileManager)
|
||||
case constants.Surfshark:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -127,11 +126,3 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity,
|
||||
func (s *surfshark) GetPortForward() (port uint16, err error) {
|
||||
panic("port forwarding is not supported for surfshark")
|
||||
}
|
||||
|
||||
func (s *surfshark) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
||||
panic("port forwarding is not supported for surfshark")
|
||||
}
|
||||
|
||||
func (s *surfshark) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
||||
panic("port forwarding is not supported for surfshark")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -124,11 +123,3 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity
|
||||
func (w *windscribe) GetPortForward() (port uint16, err error) {
|
||||
panic("port forwarding is not supported for windscribe")
|
||||
}
|
||||
|
||||
func (w *windscribe) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
||||
panic("port forwarding is not supported for windscribe")
|
||||
}
|
||||
|
||||
func (w *windscribe) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
||||
panic("port forwarding is not supported for windscribe")
|
||||
}
|
||||
|
||||
@@ -7,29 +7,34 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
for _, subnet := range subnets {
|
||||
func (r *routing) AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
subnetStr := subnet.String()
|
||||
r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface)
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists { // thanks to @npawelek https://github.com/npawelek
|
||||
if err := r.removeRoute(ctx, subnet); err != nil {
|
||||
return err
|
||||
} else if exists {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface)
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err)
|
||||
}
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway, "dev", defaultInterface, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) {
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String())
|
||||
func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) {
|
||||
subnetStr := subnet.String()
|
||||
r.logger.Info("deleting route for %s", subnetStr)
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err)
|
||||
return err
|
||||
} else if !exists { // thanks to @npawelek https://github.com/npawelek
|
||||
return nil
|
||||
}
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,12 +8,16 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/golibs/command/mock_command"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_removeRoute(t *testing.T) {
|
||||
func Test_DeleteRouteVia(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
tests := map[string]struct {
|
||||
subnet net.IPNet
|
||||
runOutput string
|
||||
@@ -22,26 +26,26 @@ func Test_removeRoute(t *testing.T) {
|
||||
}{
|
||||
"no output no error": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 0},
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
},
|
||||
"error only": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 0},
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
runErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"),
|
||||
},
|
||||
"error and output": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 0},
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
runErr: fmt.Errorf("error"),
|
||||
runOutput: "output",
|
||||
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
@@ -50,12 +54,26 @@ func Test_removeRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
|
||||
commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()).
|
||||
subnetStr := tc.subnet.String()
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("deleting route for %s")
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr).
|
||||
Return(tc.runOutput, tc.runErr).Times(1)
|
||||
r := &routing{commander: commander}
|
||||
err := r.removeRoute(context.Background(), tc.subnet)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`)
|
||||
fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil)
|
||||
r := &routing{
|
||||
logger: logger,
|
||||
commander: commander,
|
||||
fileManager: fileManager,
|
||||
}
|
||||
|
||||
err := r.DeleteRouteVia(ctx, tc.subnet)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
|
||||
@@ -10,7 +10,8 @@ import (
|
||||
)
|
||||
|
||||
type Routing interface {
|
||||
AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
||||
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
}
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
var previousPort uint16
|
||||
for ctx.Err() == nil {
|
||||
nameserver := l.dnsSettings.PlaintextAddress.String()
|
||||
if l.dnsSettings.Enabled {
|
||||
@@ -75,11 +76,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
|
||||
// TODO remove firewall rule on exit below
|
||||
if err != nil {
|
||||
|
||||
if previousPort > 0 {
|
||||
if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil {
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil {
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
previousPort = l.settings.Port
|
||||
|
||||
shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background())
|
||||
stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
|
||||
if err != nil {
|
||||
|
||||
@@ -57,6 +57,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
}
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
var previousPort uint16
|
||||
for ctx.Err() == nil {
|
||||
err := l.conf.MakeConf(
|
||||
l.settings.LogLevel,
|
||||
@@ -69,11 +70,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
|
||||
// TODO remove firewall rule on exit below
|
||||
if err != nil {
|
||||
|
||||
if previousPort > 0 {
|
||||
if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil {
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil {
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
previousPort = l.settings.Port
|
||||
|
||||
tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background())
|
||||
stream, waitFn, err := l.conf.Start(tinyproxyCtx)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user