Get public IP and version only when DNS is ready

This commit is contained in:
Quentin McGaw
2020-09-12 18:50:42 +00:00
parent cb1520cb18
commit 464c7074d0
2 changed files with 16 additions and 13 deletions

View File

@@ -156,11 +156,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
} }
} }
connectedCh := make(chan struct{}) connectedCh, dnsReadyCh := make(chan struct{}), make(chan struct{})
signalConnected := func() { signalConnected := func() { connectedCh <- struct{}{} }
connectedCh <- struct{}{} signalDNSReady := func() { dnsReadyCh <- struct{}{} }
}
defer close(connectedCh) defer close(connectedCh)
defer close(dnsReadyCh)
if allSettings.Firewall.Enabled { if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default err := firewallConf.SetEnabled(ctx, true) // disabled by default
@@ -208,7 +208,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
restartUnbound := unboundLooper.Restart restartUnbound := unboundLooper.Restart
wg.Add(1) wg.Add(1)
// wait for restartUnbound or its ticker launched with RunRestartTicker // wait for restartUnbound or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg) go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
restartPublicIP := publicIPLooper.Restart restartPublicIP := publicIPLooper.Restart
@@ -267,7 +267,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
tickerWg.Add(2) tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation) onConnected(allSettings, logger, routingConf, portForward, restartUnbound)
case <-dnsReadyCh:
restartPublicIP() // TODO do not restart if disabled
versionInformation()
} }
} }
}() }()
@@ -374,10 +377,9 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
} }
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing, func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
portForward, restartUnbound, restartPublicIP, versionInformation func(), portForward, restartUnbound func(),
) { ) {
restartUnbound() restartUnbound()
restartPublicIP()
if allSettings.OpenVPN.Provider.PortForwarding.Enabled { if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
time.AfterFunc(5*time.Second, portForward) time.AfterFunc(5*time.Second, portForward)
} }
@@ -392,5 +394,4 @@ func onConnected(allSettings settings.Settings, logger logging.Logger, routingCo
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
} }
} }
versionInformation()
} }

View File

@@ -13,7 +13,7 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func())
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() Restart()
Start() Start()
@@ -93,7 +93,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done() <-ctx.Done()
} }
func (l *looper) waitForFirstStart(ctx context.Context) { func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) {
for { for {
select { select {
case <-l.stop: case <-l.stop:
@@ -103,6 +103,7 @@ func (l *looper) waitForFirstStart(ctx context.Context) {
if l.isEnabled() { if l.isEnabled() {
return return
} }
signalDNSReady()
l.logger.Info("not restarting because disabled") l.logger.Info("not restarting because disabled")
case <-l.start: case <-l.start:
l.setEnabled(true) l.setEnabled(true)
@@ -138,11 +139,11 @@ func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel conte
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) {
defer wg.Done() defer wg.Done()
const fallback = false const fallback = false
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
l.waitForFirstStart(ctx) l.waitForFirstStart(ctx, signalDNSReady)
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
@@ -207,6 +208,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
waitError <- err waitError <- err
}() }()
l.logger.Info("DNS over TLS is ready") l.logger.Info("DNS over TLS is ready")
signalDNSReady()
stayHere := true stayHere := true
for stayHere { for stayHere {