Better onConnected logic

- First port forward after 5 seconds
- Public IP obtained ASAP
- Logging in main only
- Allow port forward firewall with 1 second timeout local context
This commit is contained in:
Quentin McGaw
2020-04-30 12:54:48 +00:00
parent ac706bd156
commit 94255aaa38
3 changed files with 35 additions and 42 deletions

View File

@@ -57,7 +57,7 @@ func main() {
dnsConf := dns.NewConfigurator(logger, client, fileManager)
firewallConf := firewall.NewConfigurator(logger)
routingConf := routing.NewRouting(logger, fileManager)
piaConf := pia.NewConfigurator(client, fileManager, firewallConf, logger)
piaConf := pia.NewConfigurator(client, fileManager, firewallConf)
mullvadConf := mullvad.NewConfigurator(fileManager, logger)
windscribeConf := windscribe.NewConfigurator(fileManager)
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
@@ -121,15 +121,13 @@ func main() {
streamMerger.CollectLines(ctx, func(line string) {
logger.Info(line)
if strings.Contains(line, "Initialization Sequence Completed") {
time.AfterFunc(time.Second, func() {
onConnected(ctx, logger, routingConf, fileManager, piaConf,
defaultInterface,
allSettings.PIA.PortForwarding.Enabled,
allSettings.PIA.PortForwarding.Filepath,
allSettings.System.IPStatusFilepath,
allSettings.System.UID,
allSettings.System.GID)
})
go onConnected(logger, routingConf, fileManager, piaConf,
defaultInterface,
allSettings.PIA.PortForwarding.Enabled,
allSettings.PIA.PortForwarding.Filepath,
allSettings.System.IPStatusFilepath,
allSettings.System.UID,
allSettings.System.GID)
}
}, func(err error) {
logger.Error(err)
@@ -289,7 +287,8 @@ func main() {
logger.Warn("context canceled, shutting down")
}
if allSettings.PIA.PortForwarding.Enabled {
if err := piaConf.ClearPortForward(allSettings.PIA.PortForwarding.Filepath, allSettings.System.UID, allSettings.System.GID); err != nil {
logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil {
logger.Error(err)
}
}
@@ -303,7 +302,6 @@ func main() {
}
func onConnected(
ctx context.Context,
logger logging.Logger,
routingConf routing.Routing,
fileManager files.FileManager,
@@ -331,21 +329,29 @@ func onConnected(
if !portForwarding {
return
}
var port uint16
for {
port, err = piaConf.GetPortForward()
if err != nil {
logger.Error("port forwarding:", err)
logger.Info("port forwarding: retrying in 5 seconds...")
time.Sleep(5 * time.Second)
} else {
break
time.AfterFunc(5*time.Second, func() {
pfLogger := logger.WithPrefix("port forwarding: ")
var port uint16
for {
port, err = piaConf.GetPortForward()
if err != nil {
pfLogger.Error(err)
pfLogger.Info("retrying in 5 seconds...")
time.Sleep(5 * time.Second)
} else {
pfLogger.Info("port forwarded is %d", port)
break
}
}
}
if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil {
logger.Error("port forwarding:", err)
}
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
logger.Error("port forwarding:", err)
}
pfLogger.Info("writing forwarded port to %s", portForwardingFilepath)
if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil {
pfLogger.Error(err)
}
pfLogger.Info("allowing forwarded port %d through firewall", port)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
pfLogger.Error(err)
}
})
}