Better onConnected logic

- First port forward after 5 seconds
- Public IP obtained ASAP
- Logging in main only
- Allow port forward firewall with 1 second timeout local context
This commit is contained in:
Quentin McGaw
2020-04-30 12:54:48 +00:00
parent ac706bd156
commit 94255aaa38
3 changed files with 35 additions and 42 deletions

View File

@@ -57,7 +57,7 @@ func main() {
dnsConf := dns.NewConfigurator(logger, client, fileManager) dnsConf := dns.NewConfigurator(logger, client, fileManager)
firewallConf := firewall.NewConfigurator(logger) firewallConf := firewall.NewConfigurator(logger)
routingConf := routing.NewRouting(logger, fileManager) routingConf := routing.NewRouting(logger, fileManager)
piaConf := pia.NewConfigurator(client, fileManager, firewallConf, logger) piaConf := pia.NewConfigurator(client, fileManager, firewallConf)
mullvadConf := mullvad.NewConfigurator(fileManager, logger) mullvadConf := mullvad.NewConfigurator(fileManager, logger)
windscribeConf := windscribe.NewConfigurator(fileManager) windscribeConf := windscribe.NewConfigurator(fileManager)
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger) tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
@@ -121,15 +121,13 @@ func main() {
streamMerger.CollectLines(ctx, func(line string) { streamMerger.CollectLines(ctx, func(line string) {
logger.Info(line) logger.Info(line)
if strings.Contains(line, "Initialization Sequence Completed") { if strings.Contains(line, "Initialization Sequence Completed") {
time.AfterFunc(time.Second, func() { go onConnected(logger, routingConf, fileManager, piaConf,
onConnected(ctx, logger, routingConf, fileManager, piaConf, defaultInterface,
defaultInterface, allSettings.PIA.PortForwarding.Enabled,
allSettings.PIA.PortForwarding.Enabled, allSettings.PIA.PortForwarding.Filepath,
allSettings.PIA.PortForwarding.Filepath, allSettings.System.IPStatusFilepath,
allSettings.System.IPStatusFilepath, allSettings.System.UID,
allSettings.System.UID, allSettings.System.GID)
allSettings.System.GID)
})
} }
}, func(err error) { }, func(err error) {
logger.Error(err) logger.Error(err)
@@ -289,7 +287,8 @@ func main() {
logger.Warn("context canceled, shutting down") logger.Warn("context canceled, shutting down")
} }
if allSettings.PIA.PortForwarding.Enabled { if allSettings.PIA.PortForwarding.Enabled {
if err := piaConf.ClearPortForward(allSettings.PIA.PortForwarding.Filepath, allSettings.System.UID, allSettings.System.GID); err != nil { logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil {
logger.Error(err) logger.Error(err)
} }
} }
@@ -303,7 +302,6 @@ func main() {
} }
func onConnected( func onConnected(
ctx context.Context,
logger logging.Logger, logger logging.Logger,
routingConf routing.Routing, routingConf routing.Routing,
fileManager files.FileManager, fileManager files.FileManager,
@@ -331,21 +329,29 @@ func onConnected(
if !portForwarding { if !portForwarding {
return return
} }
var port uint16 time.AfterFunc(5*time.Second, func() {
for { pfLogger := logger.WithPrefix("port forwarding: ")
port, err = piaConf.GetPortForward() var port uint16
if err != nil { for {
logger.Error("port forwarding:", err) port, err = piaConf.GetPortForward()
logger.Info("port forwarding: retrying in 5 seconds...") if err != nil {
time.Sleep(5 * time.Second) pfLogger.Error(err)
} else { pfLogger.Info("retrying in 5 seconds...")
break time.Sleep(5 * time.Second)
} else {
pfLogger.Info("port forwarded is %d", port)
break
}
} }
} pfLogger.Info("writing forwarded port to %s", portForwardingFilepath)
if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil { if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil {
logger.Error("port forwarding:", err) pfLogger.Error(err)
} }
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { pfLogger.Info("allowing forwarded port %d through firewall", port)
logger.Error("port forwarding:", err) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
} defer cancel()
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
pfLogger.Error(err)
}
})
} }

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/golibs/crypto/random" "github.com/qdm12/golibs/crypto/random"
"github.com/qdm12/golibs/files" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/verification" "github.com/qdm12/golibs/verification"
"github.com/qdm12/private-internet-access-docker/internal/firewall" "github.com/qdm12/private-internet-access-docker/internal/firewall"
@@ -20,7 +19,6 @@ type Configurator interface {
BuildConf(connections []models.OpenVPNConnection, encryption models.PIAEncryption, verbosity, uid, gid int, root bool, cipher, auth string) (err error) BuildConf(connections []models.OpenVPNConnection, encryption models.PIAEncryption, verbosity, uid, gid int, root bool, cipher, auth string) (err error)
GetPortForward() (port uint16, err error) GetPortForward() (port uint16, err error)
WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error)
ClearPortForward(filepath models.Filepath, uid, gid int) (err error)
AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error)
} }
@@ -28,19 +26,17 @@ type configurator struct {
client network.Client client network.Client
fileManager files.FileManager fileManager files.FileManager
firewall firewall.Configurator firewall firewall.Configurator
logger logging.Logger
random random.Random random random.Random
verifyPort func(port string) error verifyPort func(port string) error
lookupIP func(host string) ([]net.IP, error) lookupIP func(host string) ([]net.IP, error)
} }
// NewConfigurator returns a new Configurator object // NewConfigurator returns a new Configurator object
func NewConfigurator(client network.Client, fileManager files.FileManager, firewall firewall.Configurator, logger logging.Logger) Configurator { func NewConfigurator(client network.Client, fileManager files.FileManager, firewall firewall.Configurator) Configurator {
return &configurator{ return &configurator{
client: client, client: client,
fileManager: fileManager, fileManager: fileManager,
firewall: firewall, firewall: firewall,
logger: logger.WithPrefix("PIA configurator: "),
random: random.NewRandom(), random: random.NewRandom(),
verifyPort: verification.NewVerifier().VerifyPort, verifyPort: verification.NewVerifier().VerifyPort,
lookupIP: net.LookupIP} lookupIP: net.LookupIP}

View File

@@ -13,7 +13,6 @@ import (
) )
func (c *configurator) GetPortForward() (port uint16, err error) { func (c *configurator) GetPortForward() (port uint16, err error) {
c.logger.Info("Obtaining port to be forwarded")
b, err := c.random.GenerateRandomBytes(32) b, err := c.random.GenerateRandomBytes(32)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -35,12 +34,10 @@ func (c *configurator) GetPortForward() (port uint16, err error) {
if err := json.Unmarshal(content, &body); err != nil { if err := json.Unmarshal(content, &body); err != nil {
return 0, fmt.Errorf("port forwarding response: %w", err) return 0, fmt.Errorf("port forwarding response: %w", err)
} }
c.logger.Info("Port forwarded is %d", body.Port)
return body.Port, nil return body.Port, nil
} }
func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) { func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
c.logger.Info("Writing forwarded port to %s", filepath)
return c.fileManager.WriteLinesToFile( return c.fileManager.WriteLinesToFile(
string(filepath), string(filepath),
[]string{fmt.Sprintf("%d", port)}, []string{fmt.Sprintf("%d", port)},
@@ -49,11 +46,5 @@ func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, u
} }
func (c *configurator) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { func (c *configurator) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
c.logger.Info("Allowing forwarded port %d through firewall", port)
return c.firewall.AllowInputTrafficOnPort(ctx, device, port) return c.firewall.AllowInputTrafficOnPort(ctx, device, port)
} }
func (c *configurator) ClearPortForward(filepath models.Filepath, uid, gid int) (err error) {
c.logger.Info("Clearing forwarded port status file %s", filepath)
return c.fileManager.Remove(string(filepath))
}