From 464c7074d0596e855d678ee5e92ab8c6bb8b91fc Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 12 Sep 2020 18:50:42 +0000 Subject: [PATCH] Get public IP and version only when DNS is ready --- cmd/gluetun/main.go | 19 ++++++++++--------- internal/dns/loop.go | 10 ++++++---- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 03583ab3..90330b14 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -156,11 +156,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go } } - connectedCh := make(chan struct{}) - signalConnected := func() { - connectedCh <- struct{}{} - } + connectedCh, dnsReadyCh := make(chan struct{}), make(chan struct{}) + signalConnected := func() { connectedCh <- struct{}{} } + signalDNSReady := func() { dnsReadyCh <- struct{}{} } defer close(connectedCh) + defer close(dnsReadyCh) if allSettings.Firewall.Enabled { err := firewallConf.SetEnabled(ctx, true) // disabled by default @@ -208,7 +208,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go restartUnbound := unboundLooper.Restart wg.Add(1) // wait for restartUnbound or its ticker launched with RunRestartTicker - go unboundLooper.Run(ctx, wg) + go unboundLooper.Run(ctx, wg, signalDNSReady) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) restartPublicIP := publicIPLooper.Restart @@ -267,7 +267,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go tickerWg.Add(2) go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) - onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation) + onConnected(allSettings, logger, routingConf, portForward, restartUnbound) + case <-dnsReadyCh: + restartPublicIP() // TODO do not restart if disabled + versionInformation() } } }() @@ -374,10 +377,9 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, } func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing, - portForward, restartUnbound, restartPublicIP, versionInformation func(), + portForward, restartUnbound func(), ) { restartUnbound() - restartPublicIP() if allSettings.OpenVPN.Provider.PortForwarding.Enabled { time.AfterFunc(5*time.Second, portForward) } @@ -392,5 +394,4 @@ func onConnected(allSettings settings.Settings, logger logging.Logger, routingCo logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) } } - versionInformation() } diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 76b8d874..eb11c4e7 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -13,7 +13,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) Restart() Start() @@ -93,7 +93,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) { <-ctx.Done() } -func (l *looper) waitForFirstStart(ctx context.Context) { +func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) { for { select { case <-l.stop: @@ -103,6 +103,7 @@ func (l *looper) waitForFirstStart(ctx context.Context) { if l.isEnabled() { return } + signalDNSReady() l.logger.Info("not restarting because disabled") case <-l.start: l.setEnabled(true) @@ -138,11 +139,11 @@ func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel conte } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { +func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) { defer wg.Done() const fallback = false l.useUnencryptedDNS(fallback) - l.waitForFirstStart(ctx) + l.waitForFirstStart(ctx, signalDNSReady) if ctx.Err() != nil { return } @@ -207,6 +208,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { waitError <- err }() l.logger.Info("DNS over TLS is ready") + signalDNSReady() stayHere := true for stayHere {