From 83cf59b93e3aa2355c90c936da6e89b3fe200396 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 15 Jul 2020 23:51:34 +0000 Subject: [PATCH] Start and Stop for dns over tls --- internal/dns/loop.go | 81 +++++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 1eb7a66a..17eb4cb9 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -16,6 +16,8 @@ type Looper interface { Run(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context) Restart() + Start() + Stop() } type looper struct { @@ -26,6 +28,8 @@ type looper struct { uid int gid int restart chan struct{} + start chan struct{} + stop chan struct{} } func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, @@ -38,10 +42,14 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, gid: gid, streamMerger: streamMerger, restart: make(chan struct{}), + start: make(chan struct{}), + stop: make(chan struct{}), } } func (l *looper) Restart() { l.restart <- struct{}{} } +func (l *looper) Start() { l.start <- struct{}{} } +func (l *looper) Stop() { l.stop <- struct{}{} } func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Warn(err) @@ -55,10 +63,18 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() l.fallbackToUnencryptedDNS() - select { - case <-l.restart: - case <-ctx.Done(): - return + waitForStart := true + for waitForStart { + select { + case <-l.stop: + l.logger.Info("not started yet") + case <-l.restart: + waitForStart = false + case <-l.start: + waitForStart = false + case <-ctx.Done(): + return + } } defer l.logger.Warn("loop exited") @@ -66,11 +82,17 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { var unboundCancel context.CancelFunc = func() {} var waitError chan error triggeredRestart := false + l.settings.Enabled = true for ctx.Err() == nil { - if !l.settings.Enabled { - // wait for another restart signal to recheck if it is enabled + for !l.settings.Enabled { + // wait for a signal to re-enable select { + case <-l.stop: + l.logger.Info("already disabled") case <-l.restart: + l.settings.Enabled = true + case <-l.start: + l.settings.Enabled = true case <-ctx.Done(): unboundCancel() return @@ -124,23 +146,36 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { waitError <- err }() - // Wait for one of the three cases below - select { - case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") - unboundCancel() - <-waitError - close(waitError) - return - case <-l.restart: // triggered restart - l.logger.Info("restarting") - // unboundCancel occurs next loop run when the setup is complete - triggeredRestart = true - case err := <-waitError: // unexpected error - close(waitError) - unboundCancel() - l.fallbackToUnencryptedDNS() - l.logAndWait(ctx, err) + stayHere := true + for stayHere { + select { + case <-ctx.Done(): + l.logger.Warn("context canceled: exiting loop") + unboundCancel() + <-waitError + close(waitError) + return + case <-l.restart: // triggered restart + l.logger.Info("restarting") + // unboundCancel occurs next loop run when the setup is complete + triggeredRestart = true + stayHere = false + case <-l.start: + l.logger.Info("already started") + case <-l.stop: + l.logger.Info("stopping") + unboundCancel() + <-waitError + close(waitError) + l.settings.Enabled = false + stayHere = false + case err := <-waitError: // unexpected error + close(waitError) + unboundCancel() + l.fallbackToUnencryptedDNS() + l.logAndWait(ctx, err) + stayHere = false + } } } unboundCancel()