Get public IP and version only when DNS is ready

This commit is contained in:
Quentin McGaw
2020-09-12 18:50:42 +00:00
parent cb1520cb18
commit 464c7074d0
2 changed files with 16 additions and 13 deletions

View File

@@ -156,11 +156,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
}
}
connectedCh := make(chan struct{})
signalConnected := func() {
connectedCh <- struct{}{}
}
connectedCh, dnsReadyCh := make(chan struct{}), make(chan struct{})
signalConnected := func() { connectedCh <- struct{}{} }
signalDNSReady := func() { dnsReadyCh <- struct{}{} }
defer close(connectedCh)
defer close(dnsReadyCh)
if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default
@@ -208,7 +208,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
restartUnbound := unboundLooper.Restart
wg.Add(1)
// wait for restartUnbound or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg)
go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
restartPublicIP := publicIPLooper.Restart
@@ -267,7 +267,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound)
case <-dnsReadyCh:
restartPublicIP() // TODO do not restart if disabled
versionInformation()
}
}
}()
@@ -374,10 +377,9 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
}
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
portForward, restartUnbound, restartPublicIP, versionInformation func(),
portForward, restartUnbound func(),
) {
restartUnbound()
restartPublicIP()
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
time.AfterFunc(5*time.Second, portForward)
}
@@ -392,5 +394,4 @@ func onConnected(allSettings settings.Settings, logger logging.Logger, routingCo
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
}
}
versionInformation()
}

View File

@@ -13,7 +13,7 @@ import (
)
type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup)
Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func())
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart()
Start()
@@ -93,7 +93,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done()
}
func (l *looper) waitForFirstStart(ctx context.Context) {
func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) {
for {
select {
case <-l.stop:
@@ -103,6 +103,7 @@ func (l *looper) waitForFirstStart(ctx context.Context) {
if l.isEnabled() {
return
}
signalDNSReady()
l.logger.Info("not restarting because disabled")
case <-l.start:
l.setEnabled(true)
@@ -138,11 +139,11 @@ func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel conte
}
}
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) {
defer wg.Done()
const fallback = false
l.useUnencryptedDNS(fallback)
l.waitForFirstStart(ctx)
l.waitForFirstStart(ctx, signalDNSReady)
if ctx.Err() != nil {
return
}
@@ -207,6 +208,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
waitError <- err
}()
l.logger.Info("DNS over TLS is ready")
signalDNSReady()
stayHere := true
for stayHere {