diff --git a/cmd/main.go b/cmd/main.go index 5f657954..2b8b81dc 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -57,7 +57,7 @@ func main() { dnsConf := dns.NewConfigurator(logger, client, fileManager) firewallConf := firewall.NewConfigurator(logger) routingConf := routing.NewRouting(logger, fileManager) - piaConf := pia.NewConfigurator(client, fileManager, firewallConf, logger) + piaConf := pia.NewConfigurator(client, fileManager, firewallConf) mullvadConf := mullvad.NewConfigurator(fileManager, logger) windscribeConf := windscribe.NewConfigurator(fileManager) tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger) @@ -121,15 +121,13 @@ func main() { streamMerger.CollectLines(ctx, func(line string) { logger.Info(line) if strings.Contains(line, "Initialization Sequence Completed") { - time.AfterFunc(time.Second, func() { - onConnected(ctx, logger, routingConf, fileManager, piaConf, - defaultInterface, - allSettings.PIA.PortForwarding.Enabled, - allSettings.PIA.PortForwarding.Filepath, - allSettings.System.IPStatusFilepath, - allSettings.System.UID, - allSettings.System.GID) - }) + go onConnected(logger, routingConf, fileManager, piaConf, + defaultInterface, + allSettings.PIA.PortForwarding.Enabled, + allSettings.PIA.PortForwarding.Filepath, + allSettings.System.IPStatusFilepath, + allSettings.System.UID, + allSettings.System.GID) } }, func(err error) { logger.Error(err) @@ -289,7 +287,8 @@ func main() { logger.Warn("context canceled, shutting down") } if allSettings.PIA.PortForwarding.Enabled { - if err := piaConf.ClearPortForward(allSettings.PIA.PortForwarding.Filepath, allSettings.System.UID, allSettings.System.GID); err != nil { + logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath) + if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil { logger.Error(err) } } @@ -303,7 +302,6 @@ func main() { } func onConnected( - ctx context.Context, logger logging.Logger, routingConf routing.Routing, fileManager files.FileManager, @@ -331,21 +329,29 @@ func onConnected( if !portForwarding { return } - var port uint16 - for { - port, err = piaConf.GetPortForward() - if err != nil { - logger.Error("port forwarding:", err) - logger.Info("port forwarding: retrying in 5 seconds...") - time.Sleep(5 * time.Second) - } else { - break + time.AfterFunc(5*time.Second, func() { + pfLogger := logger.WithPrefix("port forwarding: ") + var port uint16 + for { + port, err = piaConf.GetPortForward() + if err != nil { + pfLogger.Error(err) + pfLogger.Info("retrying in 5 seconds...") + time.Sleep(5 * time.Second) + } else { + pfLogger.Info("port forwarded is %d", port) + break + } } - } - if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil { - logger.Error("port forwarding:", err) - } - if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { - logger.Error("port forwarding:", err) - } + pfLogger.Info("writing forwarded port to %s", portForwardingFilepath) + if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil { + pfLogger.Error(err) + } + pfLogger.Info("allowing forwarded port %d through firewall", port) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { + pfLogger.Error(err) + } + }) } diff --git a/internal/pia/pia.go b/internal/pia/pia.go index 41b685d5..aae37f0d 100644 --- a/internal/pia/pia.go +++ b/internal/pia/pia.go @@ -6,7 +6,6 @@ import ( "github.com/qdm12/golibs/crypto/random" "github.com/qdm12/golibs/files" - "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" "github.com/qdm12/golibs/verification" "github.com/qdm12/private-internet-access-docker/internal/firewall" @@ -20,7 +19,6 @@ type Configurator interface { BuildConf(connections []models.OpenVPNConnection, encryption models.PIAEncryption, verbosity, uid, gid int, root bool, cipher, auth string) (err error) GetPortForward() (port uint16, err error) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) - ClearPortForward(filepath models.Filepath, uid, gid int) (err error) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) } @@ -28,19 +26,17 @@ type configurator struct { client network.Client fileManager files.FileManager firewall firewall.Configurator - logger logging.Logger random random.Random verifyPort func(port string) error lookupIP func(host string) ([]net.IP, error) } // NewConfigurator returns a new Configurator object -func NewConfigurator(client network.Client, fileManager files.FileManager, firewall firewall.Configurator, logger logging.Logger) Configurator { +func NewConfigurator(client network.Client, fileManager files.FileManager, firewall firewall.Configurator) Configurator { return &configurator{ client: client, fileManager: fileManager, firewall: firewall, - logger: logger.WithPrefix("PIA configurator: "), random: random.NewRandom(), verifyPort: verification.NewVerifier().VerifyPort, lookupIP: net.LookupIP} diff --git a/internal/pia/portforward.go b/internal/pia/portforward.go index fde9adb0..116aa2b5 100644 --- a/internal/pia/portforward.go +++ b/internal/pia/portforward.go @@ -13,7 +13,6 @@ import ( ) func (c *configurator) GetPortForward() (port uint16, err error) { - c.logger.Info("Obtaining port to be forwarded") b, err := c.random.GenerateRandomBytes(32) if err != nil { return 0, err @@ -35,12 +34,10 @@ func (c *configurator) GetPortForward() (port uint16, err error) { if err := json.Unmarshal(content, &body); err != nil { return 0, fmt.Errorf("port forwarding response: %w", err) } - c.logger.Info("Port forwarded is %d", body.Port) return body.Port, nil } func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { - c.logger.Info("Writing forwarded port to %s", filepath) return c.fileManager.WriteLinesToFile( string(filepath), []string{fmt.Sprintf("%d", port)}, @@ -49,11 +46,5 @@ func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, u } func (c *configurator) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { - c.logger.Info("Allowing forwarded port %d through firewall", port) return c.firewall.AllowInputTrafficOnPort(ctx, device, port) } - -func (c *configurator) ClearPortForward(filepath models.Filepath, uid, gid int) (err error) { - c.logger.Info("Clearing forwarded port status file %s", filepath) - return c.fileManager.Remove(string(filepath)) -}