From ac3ff095a135e67927d7152e29cae16736b22143 Mon Sep 17 00:00:00 2001
From: "Quentin McGaw (desktop)"
Date: Fri, 16 Jul 2021 19:00:56 +0000
Subject: [PATCH] Maint: rework DNS run loop

- Fix fragile user triggered logic
- Simplify state
- Lock loop when crashed
---
 cmd/gluetun/main.go          |   2 +-
 internal/dns/loop.go         | 167 ++++++++++++++++++-----------------
 internal/dns/state.go        | 148 ++++++++++++++++++++-----------
 internal/server/dns.go       |   2 +-
 internal/server/handlerv0.go |   4 +-
 5 files changed, 190 insertions(+), 133 deletions(-)

diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go
index 23e3e5e6..562fceda 100644
--- a/cmd/gluetun/main.go
+++ b/cmd/gluetun/main.go
@@ -453,7 +453,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
     }
 
     if unboundLooper.GetSettings().Enabled {
-        _, _ = unboundLooper.SetStatus(ctx, constants.Running)
+        _, _ = unboundLooper.ApplyStatus(ctx, constants.Running)
     }
 
     restartTickerCancel() // stop previous restart tickers
diff --git a/internal/dns/loop.go b/internal/dns/loop.go
index 49468721..0e6c2200 100644
--- a/internal/dns/loop.go
+++ b/internal/dns/loop.go
@@ -6,7 +6,6 @@ import (
     "errors"
     "net"
     "net/http"
-    "sync"
     "time"
 
     "github.com/qdm12/dns/pkg/blacklist"
@@ -24,7 +23,7 @@ type Looper interface {
     Run(ctx context.Context, done chan<- struct{})
     RunRestartTicker(ctx context.Context, done chan<- struct{})
     GetStatus() (status models.LoopStatus)
-    SetStatus(ctx context.Context, status models.LoopStatus) (
+    ApplyStatus(ctx context.Context, status models.LoopStatus) (
         outcome string, err error)
     GetSettings() (settings configuration.DNS)
     SetSettings(ctx context.Context, settings configuration.DNS) (
@@ -32,17 +31,16 @@ type looper struct {
-    state        state
+    state        *state
     conf         unbound.Configurator
     blockBuilder blacklist.Builder
     client       *http.Client
     logger       logging.Logger
-    loopLock     sync.Mutex
-    start        chan struct{}
-    running      chan models.LoopStatus
-    stop         chan struct{}
-    stopped      chan struct{}
-    updateTicker chan struct{}
+    start        <-chan struct{}
+    running      chan<- models.LoopStatus
+    stop         <-chan struct{}
+    stopped      <-chan struct{}
+    updateTicker <-chan struct{}
     backoffTime  time.Duration
     timeNow      func() time.Time
     timeSince    func(time.Time) time.Duration
@@ -53,20 +51,25 @@ const defaultBackoffTime = 10 * time.Second
 
 func NewLooper(conf unbound.Configurator, settings configuration.DNS,
     client *http.Client, logger logging.Logger, openFile os.OpenFileFunc) Looper {
+    start := make(chan struct{})
+    running := make(chan models.LoopStatus)
+    stop := make(chan struct{})
+    stopped := make(chan struct{})
+    updateTicker := make(chan struct{})
+
+    state := newState(constants.Stopped, settings,
+        start, running, stop, stopped, updateTicker)
+
     return &looper{
-        state: state{
-            status:   constants.Stopped,
-            settings: settings,
-        },
+        state:        state,
         conf:         conf,
         blockBuilder: blacklist.NewBuilder(client),
         client:       client,
         logger:       logger,
-        start:        make(chan struct{}),
-        running:      make(chan models.LoopStatus),
-        stop:         make(chan struct{}),
-        stopped:      make(chan struct{}),
-        updateTicker: make(chan struct{}),
+        start:        start,
+        running:      running,
+        stop:         stop,
+        stopped:      stopped,
+        updateTicker: updateTicker,
         backoffTime:  defaultBackoffTime,
         timeNow:      time.Now,
         timeSince:    time.Since,
@@ -78,7 +81,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
     if err != nil {
         l.logger.Warn(err)
     }
-    l.logger.Info("attempting restart in %s", l.backoffTime)
+    l.logger.Info("attempting restart in " + l.backoffTime.String())
     timer := time.NewTimer(l.backoffTime)
     l.backoffTime *= 2
     select {
@@ -90,7 +93,16 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
     }
 }
 
-func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit
+func (l *looper) signalOrSetStatus(userTriggered *bool, status models.LoopStatus) {
+    if *userTriggered {
+        *userTriggered = false
+        l.running <- status
+    } else {
+        l.state.SetStatus(status)
+    }
+}
+
+func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
     defer close(done)
 
     const fallback = false
@@ -103,46 +115,45 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog
         return
     }
 
-    crashed := false
-    l.backoffTime = defaultBackoffTime
+    userTriggered := true
 
     for ctx.Err() == nil {
         // Upper scope variables for Unbound only
         // Their values are to be used if DOT=off
-        var waitError chan error
-        var unboundCancel context.CancelFunc
-        var closeStreams func()
+        waitError := make(chan error)
+        unboundCancel := func() { waitError <- nil }
+        closeStreams := func() {}
 
         for l.GetSettings().Enabled {
-            if ctx.Err() != nil {
-                if !crashed {
-                    l.running <- constants.Stopped
-                }
-                return
-            }
             var err error
-            unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed)
+            unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx)
+            if err == nil {
+                l.backoffTime = defaultBackoffTime
+                l.logger.Info("ready")
+                l.signalOrSetStatus(&userTriggered, constants.Running)
+                break
+            }
+
+            l.signalOrSetStatus(&userTriggered, constants.Crashed)
+
             if ctx.Err() != nil {
                 return
             }
-            if err != nil {
-                if !errors.Is(err, errUpdateFiles) {
-                    const fallback = true
-                    l.useUnencryptedDNS(fallback)
-                }
-                l.logAndWait(ctx, err)
-                continue
+
+            if !errors.Is(err, errUpdateFiles) {
+                const fallback = true
+                l.useUnencryptedDNS(fallback)
             }
-            break
+            l.logAndWait(ctx, err)
         }
+
         if !l.GetSettings().Enabled {
             const fallback = false
             l.useUnencryptedDNS(fallback)
-            waitError := make(chan error)
-            unboundCancel = func() { waitError <- nil }
-            closeStreams = func() {}
         }
+        userTriggered = false
+
         stayHere := true
         for stayHere {
             select {
@@ -153,31 +164,36 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog
                 closeStreams()
                 return
             case <-l.stop:
+                userTriggered = true
                 l.logger.Info("stopping")
                 const fallback = false
                 l.useUnencryptedDNS(fallback)
                 unboundCancel()
                 <-waitError
+                // do not close waitError or the waitError
+                // select case will trigger
+                closeStreams()
                 l.stopped <- struct{}{}
             case <-l.start:
+                userTriggered = true
                 l.logger.Info("starting")
                 stayHere = false
            case err := <-waitError: // unexpected error
+                close(waitError)
+                closeStreams()
+
+                l.state.Lock() // prevent SetStatus from running in parallel
+
                 unboundCancel()
-                if ctx.Err() != nil {
-                    close(waitError)
-                    closeStreams()
-                    return
-                }
-                l.state.setStatusWithLock(constants.Crashed)
+                l.state.SetStatus(constants.Crashed)
                 const fallback = true
                 l.useUnencryptedDNS(fallback)
                 l.logAndWait(ctx, err)
                 stayHere = false
+
+                l.state.Unlock()
             }
         }
-        close(waitError)
-        closeStreams()
     }
 }
 
@@ -186,11 +202,10 @@ var errUpdateFiles = errors.New("cannot update files")
 // Returning cancel == nil signals we want to re-run setupUnbound
 // Returning err == errUpdateFiles signals we should not fall back
 // on the plaintext DNS as DOT is still up and running.
-func (l *looper) setupUnbound(ctx context.Context, previousCrashed bool) (
+func (l *looper) setupUnbound(ctx context.Context) (
     cancel context.CancelFunc, waitError chan error, closeStreams func(), err error) {
     err = l.updateFiles(ctx)
     if err != nil {
-        l.state.setStatusWithLock(constants.Crashed)
         return nil, nil, nil, errUpdateFiles
     }
 
@@ -200,14 +215,16 @@ func (l *looper) setupUnbound(ctx context.Context, previousCrashed bool) (
     stdoutLines, stderrLines, waitError, err := l.conf.Start(unboundCtx,
         settings.Unbound.VerbosityDetailsLevel)
     if err != nil {
         cancel()
-        if !previousCrashed {
-            l.running <- constants.Crashed
-        }
         return nil, nil, nil, err
     }
     collectLinesDone := make(chan struct{})
     go l.collectLines(stdoutLines, stderrLines, collectLinesDone)
+    closeStreams = func() {
+        close(stdoutLines)
+        close(stderrLines)
+        <-collectLinesDone
+    }
 
     // use Unbound
     nameserver.UseDNSInternally(net.IP{127, 0, 0, 1})
@@ -218,32 +235,13 @@
     }
 
     if err := check.WaitForDNS(ctx, net.DefaultResolver); err != nil {
-        if !previousCrashed {
-            l.running <- constants.Crashed
-        }
         cancel()
         <-waitError
         close(waitError)
-        close(stdoutLines)
-        close(stderrLines)
-        <-collectLinesDone
+        closeStreams()
         return nil, nil, nil, err
     }
 
-    l.logger.Info("ready")
-    if !previousCrashed {
-        l.running <- constants.Running
-    } else {
-        l.backoffTime = defaultBackoffTime
-        l.state.setStatusWithLock(constants.Running)
-    }
-
-    closeStreams = func() {
-        close(stdoutLines)
-        close(stderrLines)
-        <-collectLinesDone
-    }
-
     return cancel, waitError, closeStreams, nil
 }
 
@@ -304,15 +302,15 @@ func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
             status := l.GetStatus()
             if status == constants.Running {
                 if err := l.updateFiles(ctx); err != nil {
-                    l.state.setStatusWithLock(constants.Crashed)
+                    l.state.SetStatus(constants.Crashed)
                     l.logger.Error(err)
                     l.logger.Warn("skipping Unbound restart due to failed files update")
                     continue
                 }
             }
 
-            _, _ = l.SetStatus(ctx, constants.Stopped)
-            _, _ = l.SetStatus(ctx, constants.Running)
+            _, _ = l.ApplyStatus(ctx, constants.Stopped)
+            _, _ = l.ApplyStatus(ctx, constants.Running)
 
             settings := l.GetSettings()
             timer.Reset(settings.UpdatePeriod)
@@ -358,3 +356,14 @@ func (l *looper) updateFiles(ctx context.Context) (err error) {
 
     return l.conf.MakeUnboundConf(settings.Unbound)
 }
+
+func (l *looper) GetStatus() (status models.LoopStatus) { return l.state.GetStatus() }
+func (l *looper) ApplyStatus(ctx context.Context, status models.LoopStatus) (
+    outcome string, err error) {
+    return l.state.ApplyStatus(ctx, status)
+}
+func (l *looper) GetSettings() (settings configuration.DNS) { return l.state.GetSettings() }
+func (l *looper) SetSettings(ctx context.Context, settings configuration.DNS) (
+    outcome string) {
+    return l.state.SetSettings(ctx, settings)
+}
diff --git a/internal/dns/state.go b/internal/dns/state.go
index 52f560d1..4d7b7774 100644
--- a/internal/dns/state.go
+++ b/internal/dns/state.go
@@ -12,72 +12,114 @@ import (
     "github.com/qdm12/gluetun/internal/models"
 )
 
-type state struct {
-    status     models.LoopStatus
-    settings   configuration.DNS
-    statusMu   sync.RWMutex
-    settingsMu sync.RWMutex
+func newState(status models.LoopStatus, settings configuration.DNS,
+    start chan<- struct{}, running <-chan models.LoopStatus,
+    stop chan<- struct{}, stopped <-chan struct{},
+    updateTicker chan<- struct{}) *state {
+    return &state{
+        status:       status,
+        settings:     settings,
+        start:        start,
+        running:      running,
+        stop:         stop,
+        stopped:      stopped,
+        updateTicker: updateTicker,
+    }
 }
 
-func (s *state) setStatusWithLock(status models.LoopStatus) {
+type state struct {
+    loopMu sync.RWMutex
+
+    status   models.LoopStatus
+    statusMu sync.RWMutex
+
+    settings   configuration.DNS
+    settingsMu sync.RWMutex
+
+    start   chan<- struct{}
+    running <-chan models.LoopStatus
+    stop    chan<- struct{}
+    stopped <-chan struct{}
+
+    updateTicker chan<- struct{}
+}
+
+func (s *state) Lock()   { s.loopMu.Lock() }
+func (s *state) Unlock() { s.loopMu.Unlock() }
+
+// SetStatus sets the status thread safely.
+// It should only be called by the loop internal code since
+// it does not interact with the loop code directly.
+func (s *state) SetStatus(status models.LoopStatus) {
     s.statusMu.Lock()
     defer s.statusMu.Unlock()
     s.status = status
 }
 
-func (l *looper) GetStatus() (status models.LoopStatus) {
-    l.state.statusMu.RLock()
-    defer l.state.statusMu.RUnlock()
-    return l.state.status
+// GetStatus gets the status thread safely.
+func (s *state) GetStatus() (status models.LoopStatus) {
+    s.statusMu.RLock()
+    defer s.statusMu.RUnlock()
+    return s.status
 }
 
 var ErrInvalidStatus = errors.New("invalid status")
 
-func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
+// ApplyStatus sends signals to the running loop depending on the
+// current status and status requested, such that its next status
+// matches the requested one. It is thread safe and a synchronous call
+// since it waits for the loop to fully change its status.
+func (s *state) ApplyStatus(ctx context.Context, status models.LoopStatus) (
     outcome string, err error) {
-    l.state.statusMu.Lock()
-    defer l.state.statusMu.Unlock()
-    existingStatus := l.state.status
+    // prevent simultaneous loop changes by restricting
+    // multiple SetStatus calls to run sequentially.
+    s.loopMu.Lock()
+    defer s.loopMu.Unlock()
+
+    // not a read lock as we want to modify it eventually in
+    // the code below before any other call.
+    s.statusMu.Lock()
+    existingStatus := s.status
 
     switch status {
     case constants.Running:
-        switch existingStatus {
-        case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
-            return fmt.Sprintf("already %s", existingStatus), nil
+        if existingStatus != constants.Stopped {
+            // starting, running, stopping, crashed
+            s.statusMu.Unlock()
+            return "already " + existingStatus.String(), nil
         }
-        l.loopLock.Lock()
-        defer l.loopLock.Unlock()
-        l.state.status = constants.Starting
-        l.state.statusMu.Unlock()
-        l.start <- struct{}{}
+
+        s.status = constants.Starting
+        s.statusMu.Unlock()
+        s.start <- struct{}{}
+
+        // Wait for the loop to react to the start signal
         newStatus := constants.Starting // for canceled context
         select {
         case <-ctx.Done():
-        case newStatus = <-l.running:
+        case newStatus = <-s.running:
         }
+        s.SetStatus(newStatus)
 
-        l.state.statusMu.Lock()
-        l.state.status = newStatus
         return newStatus.String(), nil
     case constants.Stopped:
-        switch existingStatus {
-        case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
-            return fmt.Sprintf("already %s", existingStatus), nil
+        if existingStatus != constants.Running {
+            return "already " + existingStatus.String(), nil
         }
-        l.loopLock.Lock()
-        defer l.loopLock.Unlock()
-        l.state.status = constants.Stopping
-        l.state.statusMu.Unlock()
-        l.stop <- struct{}{}
+        s.status = constants.Stopping
+        s.statusMu.Unlock()
+        s.stop <- struct{}{}
+
+        // Wait for the loop to react to the stop signal
         newStatus := constants.Stopping // for canceled context
         select {
         case <-ctx.Done():
-        case <-l.stopped:
+        case <-s.stopped:
             newStatus = constants.Stopped
         }
-        l.state.statusMu.Lock()
-        l.state.status = newStatus
+        s.SetStatus(newStatus)
+
         return status.String(), nil
     default:
         return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
@@ -85,32 +127,38 @@ func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
     }
 }
 
-func (l *looper) GetSettings() (settings configuration.DNS) {
-    l.state.settingsMu.RLock()
-    defer l.state.settingsMu.RUnlock()
-    return l.state.settings
+func (s *state) GetSettings() (settings configuration.DNS) {
+    s.settingsMu.RLock()
+    defer s.settingsMu.RUnlock()
+    return s.settings
 }
 
-func (l *looper) SetSettings(ctx context.Context, settings configuration.DNS) (
+func (s *state) SetSettings(ctx context.Context, settings configuration.DNS) (
     outcome string) {
-    l.state.settingsMu.Lock()
-    settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
+    s.settingsMu.Lock()
+    defer s.settingsMu.Unlock()
+
+    settingsUnchanged := reflect.DeepEqual(s.settings, settings)
     if settingsUnchanged {
-        l.state.settingsMu.Unlock()
         return "settings left unchanged"
     }
-    tempSettings := l.state.settings
+
+    // Check for only update period change
+    tempSettings := s.settings
     tempSettings.UpdatePeriod = settings.UpdatePeriod
     onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings)
-    l.state.settings = settings
-    l.state.settingsMu.Unlock()
+
+    s.settings = settings
+
     if onlyUpdatePeriodChanged {
-        l.updateTicker <- struct{}{}
+        s.updateTicker <- struct{}{}
         return "update period changed"
     }
-    _, _ = l.SetStatus(ctx, constants.Stopped)
+
+    // Restart
+    _, _ = s.ApplyStatus(ctx, constants.Stopped)
     if settings.Enabled {
-        outcome, _ = l.SetStatus(ctx, constants.Running)
+        outcome, _ = s.ApplyStatus(ctx, constants.Running)
     }
     return outcome
 }
diff --git a/internal/server/dns.go b/internal/server/dns.go
index cc65a37a..7c9a1a00 100644
--- a/internal/server/dns.go
+++ b/internal/server/dns.go
@@ -65,7 +65,7 @@ func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) {
         http.Error(w, err.Error(), http.StatusBadRequest)
         return
     }
-    outcome, err := h.looper.SetStatus(h.ctx, status)
+    outcome, err := h.looper.ApplyStatus(h.ctx, status)
     if err != nil {
         http.Error(w, err.Error(), http.StatusBadRequest)
         return
diff --git a/internal/server/handlerv0.go b/internal/server/handlerv0.go
index 3246bafe..d180c613 100644
--- a/internal/server/handlerv0.go
+++ b/internal/server/handlerv0.go
@@ -47,9 +47,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
             h.logger.Warn(err)
         }
     case "/unbound/actions/restart":
-        outcome, _ := h.dns.SetStatus(h.ctx, constants.Stopped)
+        outcome, _ := h.dns.ApplyStatus(h.ctx, constants.Stopped)
         h.logger.Info("dns: %s", outcome)
-        outcome, _ = h.dns.SetStatus(h.ctx, constants.Running)
+        outcome, _ = h.dns.ApplyStatus(h.ctx, constants.Running)
         h.logger.Info("dns: %s", outcome)
         if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil {
             h.logger.Warn(err)
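
Reviewer sketch (not part of the patch): to make the start/running/stop/stopped handshake between state.ApplyStatus and the Run loop easier to follow, here is a minimal, self-contained Go illustration of the same signalling pattern. All identifiers (miniState, plain string statuses) are simplified placeholders and not the gluetun types; the real loop also reports Crashed on the running channel via signalOrSetStatus, which is omitted here for brevity.

// Minimal sketch of the ApplyStatus signalling pattern, under the
// assumptions stated above. Not the gluetun code itself.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type miniState struct {
	mu      sync.Mutex // serialises ApplyStatus calls, like loopMu in the patch
	status  string
	start   chan<- struct{}
	running <-chan string
	stop    chan<- struct{}
	stopped <-chan struct{}
}

// ApplyStatus signals the loop and waits synchronously for it to react,
// mirroring the flow of state.ApplyStatus in the diff.
func (s *miniState) ApplyStatus(ctx context.Context, status string) string {
	s.mu.Lock()
	defer s.mu.Unlock()
	switch status {
	case "running":
		s.start <- struct{}{} // ask the loop to start
		select {
		case <-ctx.Done():
			return "context canceled"
		case newStatus := <-s.running: // loop answers with its new status
			s.status = newStatus
			return newStatus
		}
	case "stopped":
		s.stop <- struct{}{} // ask the loop to stop
		select {
		case <-ctx.Done():
			return "context canceled"
		case <-s.stopped: // loop acknowledges it stopped
			s.status = "stopped"
			return "stopped"
		}
	default:
		return "invalid status"
	}
}

func main() {
	start := make(chan struct{})
	running := make(chan string)
	stop := make(chan struct{})
	stopped := make(chan struct{})

	// Loop goroutine: waits for a start signal, reports "running",
	// then waits for a stop signal and acknowledges it.
	go func() {
		for {
			<-start
			running <- "running" // user-triggered path: answer on running
			<-stop
			stopped <- struct{}{}
		}
	}()

	s := &miniState{start: start, running: running, stop: stop, stopped: stopped}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println(s.ApplyStatus(ctx, "running")) // running
	fmt.Println(s.ApplyStatus(ctx, "stopped")) // stopped
}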