From bb2b8b4514100209eb99a13394fd41b6f98d178d Mon Sep 17 00:00:00 2001 From: "Quentin McGaw (desktop)" Date: Thu, 15 Jul 2021 22:42:58 +0000 Subject: [PATCH] Fix: events routing exit when gluetun stops at start --- cmd/gluetun/main.go | 8 ++++---- internal/dns/loop.go | 15 ++++++++++----- internal/dns/state.go | 18 +++++++++++++----- internal/server/dns.go | 8 ++++++-- internal/server/handler.go | 7 ++++--- internal/server/handlerv0.go | 9 ++++++--- internal/server/server.go | 4 ++-- 7 files changed, 45 insertions(+), 24 deletions(-) diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 3dac1101..24a5c763 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -359,11 +359,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port)) controlServerLogging := allSettings.ControlServer.Log - httpServer := server.New(controlServerAddress, controlServerLogging, - logger.NewChild(logging.Settings{Prefix: "http server: "}), - buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) httpServerHandler, httpServerCtx, httpServerDone := goshutdown.NewGoRoutineHandler( "http server", defaultGoRoutineSettings) + httpServer := server.New(httpServerCtx, controlServerAddress, controlServerLogging, + logger.NewChild(logging.Settings{Prefix: "http server: "}), + buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) go httpServer.Run(httpServerCtx, httpServerDone) controlGroupHandler.Add(httpServerHandler) @@ -454,7 +454,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model } if unboundLooper.GetSettings().Enabled { - _, _ = unboundLooper.SetStatus(constants.Running) + _, _ = unboundLooper.SetStatus(ctx, constants.Running) } restartTickerCancel() // stop previous restart tickers diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 531d30b0..49468721 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -24,9 +24,11 @@ type Looper interface { Run(ctx context.Context, done chan<- struct{}) RunRestartTicker(ctx context.Context, done chan<- struct{}) GetStatus() (status models.LoopStatus) - SetStatus(status models.LoopStatus) (outcome string, err error) + SetStatus(ctx context.Context, status models.LoopStatus) ( + outcome string, err error) GetSettings() (settings configuration.DNS) - SetSettings(settings configuration.DNS) (outcome string) + SetSettings(ctx context.Context, settings configuration.DNS) ( + outcome string) } type looper struct { @@ -88,7 +90,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } } -func (l *looper) Run(ctx context.Context, done chan<- struct{}) { +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit defer close(done) const fallback = false @@ -120,6 +122,9 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { } var err error unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed) + if ctx.Err() != nil { + return + } if err != nil { if !errors.Is(err, errUpdateFiles) { const fallback = true @@ -306,8 +311,8 @@ func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) { } } - _, _ = l.SetStatus(constants.Stopped) - _, _ = l.SetStatus(constants.Running) + _, _ = l.SetStatus(ctx, constants.Stopped) + _, _ = l.SetStatus(ctx, constants.Running) settings := l.GetSettings() timer.Reset(settings.UpdatePeriod) diff --git a/internal/dns/state.go b/internal/dns/state.go index ec7d1c4e..b534ad29 100644 --- a/internal/dns/state.go +++ b/internal/dns/state.go @@ -1,6 +1,7 @@ package dns import ( + "context" "errors" "fmt" "reflect" @@ -32,7 +33,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) { var ErrInvalidStatus = errors.New("invalid status") -func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { +func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) ( + outcome string, err error) { l.state.statusMu.Lock() defer l.state.statusMu.Unlock() existingStatus := l.state.status @@ -48,7 +50,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) l.state.status = constants.Starting l.state.statusMu.Unlock() l.start <- struct{}{} - newStatus := <-l.running + newStatus := constants.Starting // for canceled context + select { + case <-ctx.Done(): + case newStatus = <-l.running: + } + l.state.statusMu.Lock() l.state.status = newStatus return newStatus.String(), nil @@ -78,7 +85,8 @@ func (l *looper) GetSettings() (settings configuration.DNS) { return l.state.settings } -func (l *looper) SetSettings(settings configuration.DNS) (outcome string) { +func (l *looper) SetSettings(ctx context.Context, settings configuration.DNS) ( + outcome string) { l.state.settingsMu.Lock() settingsUnchanged := reflect.DeepEqual(l.state.settings, settings) if settingsUnchanged { @@ -94,9 +102,9 @@ func (l *looper) SetSettings(settings configuration.DNS) (outcome string) { l.updateTicker <- struct{}{} return "update period changed" } - _, _ = l.SetStatus(constants.Stopped) + _, _ = l.SetStatus(ctx, constants.Stopped) if settings.Enabled { - outcome, _ = l.SetStatus(constants.Running) + outcome, _ = l.SetStatus(ctx, constants.Running) } return outcome } diff --git a/internal/server/dns.go b/internal/server/dns.go index 229597c9..cc65a37a 100644 --- a/internal/server/dns.go +++ b/internal/server/dns.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "net/http" "strings" @@ -9,14 +10,17 @@ import ( "github.com/qdm12/golibs/logging" ) -func newDNSHandler(looper dns.Looper, logger logging.Logger) http.Handler { +func newDNSHandler(ctx context.Context, looper dns.Looper, + logger logging.Logger) http.Handler { return &dnsHandler{ + ctx: ctx, looper: looper, logger: logger, } } type dnsHandler struct { + ctx context.Context looper dns.Looper logger logging.Logger } @@ -61,7 +65,7 @@ func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - outcome, err := h.looper.SetStatus(status) + outcome, err := h.looper.SetStatus(h.ctx, status) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/internal/server/handler.go b/internal/server/handler.go index 9be8cb79..cb64acca 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/http" "strings" @@ -12,7 +13,7 @@ import ( "github.com/qdm12/golibs/logging" ) -func newHandler(logger logging.Logger, logging bool, +func newHandler(ctx context.Context, logger logging.Logger, logging bool, buildInfo models.BuildInformation, openvpnLooper openvpn.Looper, unboundLooper dns.Looper, @@ -22,11 +23,11 @@ func newHandler(logger logging.Logger, logging bool, handler := &handler{} openvpn := newOpenvpnHandler(openvpnLooper, logger) - dns := newDNSHandler(unboundLooper, logger) + dns := newDNSHandler(ctx, unboundLooper, logger) updater := newUpdaterHandler(updaterLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger) - handler.v0 = newHandlerV0(logger, openvpnLooper, unboundLooper, updaterLooper) + handler.v0 = newHandlerV0(ctx, logger, openvpnLooper, unboundLooper, updaterLooper) handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater, publicip) handlerWithLog := withLogMiddleware(handler, logger, logging) diff --git a/internal/server/handlerv0.go b/internal/server/handlerv0.go index 8236cd43..00ea08d5 100644 --- a/internal/server/handlerv0.go +++ b/internal/server/handlerv0.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/http" "github.com/qdm12/gluetun/internal/constants" @@ -10,9 +11,10 @@ import ( "github.com/qdm12/golibs/logging" ) -func newHandlerV0(logger logging.Logger, +func newHandlerV0(ctx context.Context, logger logging.Logger, openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler { return &handlerV0{ + ctx: ctx, logger: logger, openvpn: openvpn, dns: dns, @@ -21,6 +23,7 @@ func newHandlerV0(logger logging.Logger, } type handlerV0 struct { + ctx context.Context logger logging.Logger openvpn openvpn.Looper dns dns.Looper @@ -44,9 +47,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.logger.Warn(err) } case "/unbound/actions/restart": - outcome, _ := h.dns.SetStatus(constants.Stopped) + outcome, _ := h.dns.SetStatus(h.ctx, constants.Stopped) h.logger.Info("dns: %s", outcome) - outcome, _ = h.dns.SetStatus(constants.Running) + outcome, _ = h.dns.SetStatus(h.ctx, constants.Running) h.logger.Info("dns: %s", outcome) if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil { h.logger.Warn(err) diff --git a/internal/server/server.go b/internal/server/server.go index ba6b393b..db84957b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -25,11 +25,11 @@ type server struct { handler http.Handler } -func New(address string, logEnabled bool, logger logging.Logger, +func New(ctx context.Context, address string, logEnabled bool, logger logging.Logger, buildInfo models.BuildInformation, openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper) Server { - handler := newHandler(logger, logEnabled, buildInfo, + handler := newHandler(ctx, logger, logEnabled, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) return &server{ address: address,