diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 17eb4cb9..a4e072b7 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -18,18 +18,24 @@ type Looper interface { Restart() Start() Stop() + GetSettings() (settings settings.DNS) + SetSettings(settings settings.DNS) } type looper struct { - conf Configurator - settings settings.DNS - logger logging.Logger - streamMerger command.StreamMerger - uid int - gid int - restart chan struct{} - start chan struct{} - stop chan struct{} + conf Configurator + settings settings.DNS + settingsMutex sync.RWMutex + logger logging.Logger + streamMerger command.StreamMerger + uid int + gid int + restart chan struct{} + start chan struct{} + stop chan struct{} + updateTicker chan struct{} + tickerReady bool + tickerReadyMutex sync.Mutex } func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, @@ -44,6 +50,7 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, restart: make(chan struct{}), start: make(chan struct{}), stop: make(chan struct{}), + updateTicker: make(chan struct{}), } } @@ -51,6 +58,35 @@ func (l *looper) Restart() { l.restart <- struct{}{} } func (l *looper) Start() { l.start <- struct{}{} } func (l *looper) Stop() { l.stop <- struct{}{} } +func (l *looper) GetSettings() (settings settings.DNS) { + l.settingsMutex.RLock() + defer l.settingsMutex.RUnlock() + return l.settings +} + +func (l *looper) SetSettings(settings settings.DNS) { + l.settingsMutex.Lock() + defer l.settingsMutex.Unlock() + updatePeriodDiffers := l.settings.UpdatePeriod != settings.UpdatePeriod + l.settings = settings + l.settingsMutex.Unlock() + if updatePeriodDiffers { + l.updateTicker <- struct{}{} + } +} + +func (l *looper) isEnabled() bool { + l.settingsMutex.RLock() + defer l.settingsMutex.RUnlock() + return l.settings.Enabled +} + +func (l *looper) setEnabled(enabled bool) { + l.settingsMutex.Lock() + defer l.settingsMutex.Unlock() + l.settings.Enabled = enabled +} + func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Warn(err) l.logger.Info("attempting restart in 10 seconds") @@ -82,23 +118,25 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { var unboundCancel context.CancelFunc = func() {} var waitError chan error triggeredRestart := false - l.settings.Enabled = true + l.setEnabled(true) for ctx.Err() == nil { - for !l.settings.Enabled { + for !l.isEnabled() { // wait for a signal to re-enable select { case <-l.stop: l.logger.Info("already disabled") case <-l.restart: - l.settings.Enabled = true + l.setEnabled(true) case <-l.start: - l.settings.Enabled = true + l.setEnabled(true) case <-ctx.Done(): unboundCancel() return } } + settings := l.GetSettings() + // Setup if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil { l.logAndWait(ctx, err) @@ -108,7 +146,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { l.logAndWait(ctx, err) continue } - if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil { + if err := l.conf.MakeUnboundConf(settings, l.uid, l.gid); err != nil { l.logAndWait(ctx, err) continue } @@ -120,7 +158,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { close(waitError) } unboundCtx, unboundCancel = context.WithCancel(context.Background()) - stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel) + stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel) if err != nil { unboundCancel() l.fallbackToUnencryptedDNS() @@ -130,8 +168,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { // Started successfully go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound")) - l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound - if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, l.settings.KeepNameserver); err != nil { // use Unbound + l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound + if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound l.logger.Error(err) } if err := l.conf.WaitForUnbound(); err != nil { @@ -167,7 +205,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { unboundCancel() <-waitError close(waitError) - l.settings.Enabled = false + l.setEnabled(false) stayHere = false case err := <-waitError: // unexpected error close(waitError) @@ -182,25 +220,27 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { } func (l *looper) fallbackToUnencryptedDNS() { + settings := l.GetSettings() + // Try with user provided plaintext ip address - targetIP := l.settings.PlaintextAddress + targetIP := settings.PlaintextAddress if targetIP != nil { l.logger.Info("falling back on plaintext DNS at address %s", targetIP) l.conf.UseDNSInternally(targetIP) - if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil { + if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil { l.logger.Error(err) } return } // Try with any IPv4 address from the providers chosen - for _, provider := range l.settings.Providers { + for _, provider := range settings.Providers { data := constants.DNSProviderMapping()[provider] for _, targetIP = range data.IPs { if targetIP.To4() != nil { l.logger.Info("falling back on plaintext DNS at address %s", targetIP) l.conf.UseDNSInternally(targetIP) - if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil { + if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil { l.logger.Error(err) } return @@ -209,21 +249,33 @@ func (l *looper) fallbackToUnencryptedDNS() { } // No IPv4 address found - l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers) + l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers) } func (l *looper) RunRestartTicker(ctx context.Context) { - if l.settings.UpdatePeriod == 0 { - return + l.tickerReadyMutex.Lock() + l.tickerReady = true + l.tickerReadyMutex.Unlock() + var ticker *time.Ticker = nil + settings := l.GetSettings() + if settings.UpdatePeriod > 0 { + ticker = time.NewTicker(settings.UpdatePeriod) } - ticker := time.NewTicker(l.settings.UpdatePeriod) for { select { case <-ctx.Done(): - ticker.Stop() + if ticker != nil { + ticker.Stop() + } return case <-ticker.C: l.restart <- struct{}{} + case <-l.updateTicker: + if ticker != nil { + ticker.Stop() + } + period := l.GetSettings().UpdatePeriod + ticker = time.NewTicker(period) } } }