Get public IP and version only when DNS is ready

This commit is contained in:
Quentin McGaw
2020-09-12 18:50:42 +00:00
parent cb1520cb18
commit 464c7074d0
2 changed files with 16 additions and 13 deletions

View File

@@ -156,11 +156,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
}
}
connectedCh := make(chan struct{})
signalConnected := func() {
connectedCh <- struct{}{}
}
connectedCh, dnsReadyCh := make(chan struct{}), make(chan struct{})
signalConnected := func() { connectedCh <- struct{}{} }
signalDNSReady := func() { dnsReadyCh <- struct{}{} }
defer close(connectedCh)
defer close(dnsReadyCh)
if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default
@@ -208,7 +208,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
restartUnbound := unboundLooper.Restart
wg.Add(1)
// wait for restartUnbound or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg)
go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
restartPublicIP := publicIPLooper.Restart
@@ -267,7 +267,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound)
case <-dnsReadyCh:
restartPublicIP() // TODO do not restart if disabled
versionInformation()
}
}
}()
@@ -374,10 +377,9 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
}
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
portForward, restartUnbound, restartPublicIP, versionInformation func(),
portForward, restartUnbound func(),
) {
restartUnbound()
restartPublicIP()
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
time.AfterFunc(5*time.Second, portForward)
}
@@ -392,5 +394,4 @@ func onConnected(allSettings settings.Settings, logger logging.Logger, routingCo
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
}
}
versionInformation()
}