Fix: events routing exit when gluetun stops at start

This commit is contained in:
Quentin McGaw (desktop)
2021-07-15 22:42:58 +00:00
parent e20b9c5774
commit bb2b8b4514
7 changed files with 45 additions and 24 deletions

View File

@@ -359,11 +359,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port)) controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port))
controlServerLogging := allSettings.ControlServer.Log controlServerLogging := allSettings.ControlServer.Log
httpServer := server.New(controlServerAddress, controlServerLogging,
logger.NewChild(logging.Settings{Prefix: "http server: "}),
buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
httpServerHandler, httpServerCtx, httpServerDone := goshutdown.NewGoRoutineHandler( httpServerHandler, httpServerCtx, httpServerDone := goshutdown.NewGoRoutineHandler(
"http server", defaultGoRoutineSettings) "http server", defaultGoRoutineSettings)
httpServer := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.NewChild(logging.Settings{Prefix: "http server: "}),
buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
go httpServer.Run(httpServerCtx, httpServerDone) go httpServer.Run(httpServerCtx, httpServerDone)
controlGroupHandler.Add(httpServerHandler) controlGroupHandler.Add(httpServerHandler)
@@ -454,7 +454,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
} }
if unboundLooper.GetSettings().Enabled { if unboundLooper.GetSettings().Enabled {
_, _ = unboundLooper.SetStatus(constants.Running) _, _ = unboundLooper.SetStatus(ctx, constants.Running)
} }
restartTickerCancel() // stop previous restart tickers restartTickerCancel() // stop previous restart tickers

View File

@@ -24,9 +24,11 @@ type Looper interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
RunRestartTicker(ctx context.Context, done chan<- struct{}) RunRestartTicker(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
GetSettings() (settings configuration.DNS) GetSettings() (settings configuration.DNS)
SetSettings(settings configuration.DNS) (outcome string) SetSettings(ctx context.Context, settings configuration.DNS) (
outcome string)
} }
type looper struct { type looper struct {
@@ -88,7 +90,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) Run(ctx context.Context, done chan<- struct{}) { func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit
defer close(done) defer close(done)
const fallback = false const fallback = false
@@ -120,6 +122,9 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
} }
var err error var err error
unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed) unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed)
if ctx.Err() != nil {
return
}
if err != nil { if err != nil {
if !errors.Is(err, errUpdateFiles) { if !errors.Is(err, errUpdateFiles) {
const fallback = true const fallback = true
@@ -306,8 +311,8 @@ func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
} }
} }
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
settings := l.GetSettings() settings := l.GetSettings()
timer.Reset(settings.UpdatePeriod) timer.Reset(settings.UpdatePeriod)

View File

@@ -1,6 +1,7 @@
package dns package dns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -32,7 +33,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
var ErrInvalidStatus = errors.New("invalid status") var ErrInvalidStatus = errors.New("invalid status")
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error) {
l.state.statusMu.Lock() l.state.statusMu.Lock()
defer l.state.statusMu.Unlock() defer l.state.statusMu.Unlock()
existingStatus := l.state.status existingStatus := l.state.status
@@ -48,7 +50,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Starting l.state.status = constants.Starting
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.start <- struct{}{} l.start <- struct{}{}
newStatus := <-l.running newStatus := constants.Starting // for canceled context
select {
case <-ctx.Done():
case newStatus = <-l.running:
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = newStatus l.state.status = newStatus
return newStatus.String(), nil return newStatus.String(), nil
@@ -78,7 +85,8 @@ func (l *looper) GetSettings() (settings configuration.DNS) {
return l.state.settings return l.state.settings
} }
func (l *looper) SetSettings(settings configuration.DNS) (outcome string) { func (l *looper) SetSettings(ctx context.Context, settings configuration.DNS) (
outcome string) {
l.state.settingsMu.Lock() l.state.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings) settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
if settingsUnchanged { if settingsUnchanged {
@@ -94,9 +102,9 @@ func (l *looper) SetSettings(settings configuration.DNS) (outcome string) {
l.updateTicker <- struct{}{} l.updateTicker <- struct{}{}
return "update period changed" return "update period changed"
} }
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
if settings.Enabled { if settings.Enabled {
outcome, _ = l.SetStatus(constants.Running) outcome, _ = l.SetStatus(ctx, constants.Running)
} }
return outcome return outcome
} }

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@@ -9,14 +10,17 @@ import (
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func newDNSHandler(looper dns.Looper, logger logging.Logger) http.Handler { func newDNSHandler(ctx context.Context, looper dns.Looper,
logger logging.Logger) http.Handler {
return &dnsHandler{ return &dnsHandler{
ctx: ctx,
looper: looper, looper: looper,
logger: logger, logger: logger,
} }
} }
type dnsHandler struct { type dnsHandler struct {
ctx context.Context
looper dns.Looper looper dns.Looper
logger logging.Logger logger logging.Logger
} }
@@ -61,7 +65,7 @@ func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
outcome, err := h.looper.SetStatus(status) outcome, err := h.looper.SetStatus(h.ctx, status)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
@@ -12,7 +13,7 @@ import (
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func newHandler(logger logging.Logger, logging bool, func newHandler(ctx context.Context, logger logging.Logger, logging bool,
buildInfo models.BuildInformation, buildInfo models.BuildInformation,
openvpnLooper openvpn.Looper, openvpnLooper openvpn.Looper,
unboundLooper dns.Looper, unboundLooper dns.Looper,
@@ -22,11 +23,11 @@ func newHandler(logger logging.Logger, logging bool,
handler := &handler{} handler := &handler{}
openvpn := newOpenvpnHandler(openvpnLooper, logger) openvpn := newOpenvpnHandler(openvpnLooper, logger)
dns := newDNSHandler(unboundLooper, logger) dns := newDNSHandler(ctx, unboundLooper, logger)
updater := newUpdaterHandler(updaterLooper, logger) updater := newUpdaterHandler(updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger)
handler.v0 = newHandlerV0(logger, openvpnLooper, unboundLooper, updaterLooper) handler.v0 = newHandlerV0(ctx, logger, openvpnLooper, unboundLooper, updaterLooper)
handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater, publicip) handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater, publicip)
handlerWithLog := withLogMiddleware(handler, logger, logging) handlerWithLog := withLogMiddleware(handler, logger, logging)

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"net/http" "net/http"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
@@ -10,9 +11,10 @@ import (
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func newHandlerV0(logger logging.Logger, func newHandlerV0(ctx context.Context, logger logging.Logger,
openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler { openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler {
return &handlerV0{ return &handlerV0{
ctx: ctx,
logger: logger, logger: logger,
openvpn: openvpn, openvpn: openvpn,
dns: dns, dns: dns,
@@ -21,6 +23,7 @@ func newHandlerV0(logger logging.Logger,
} }
type handlerV0 struct { type handlerV0 struct {
ctx context.Context
logger logging.Logger logger logging.Logger
openvpn openvpn.Looper openvpn openvpn.Looper
dns dns.Looper dns dns.Looper
@@ -44,9 +47,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.logger.Warn(err) h.logger.Warn(err)
} }
case "/unbound/actions/restart": case "/unbound/actions/restart":
outcome, _ := h.dns.SetStatus(constants.Stopped) outcome, _ := h.dns.SetStatus(h.ctx, constants.Stopped)
h.logger.Info("dns: %s", outcome) h.logger.Info("dns: %s", outcome)
outcome, _ = h.dns.SetStatus(constants.Running) outcome, _ = h.dns.SetStatus(h.ctx, constants.Running)
h.logger.Info("dns: %s", outcome) h.logger.Info("dns: %s", outcome)
if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil { if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil {
h.logger.Warn(err) h.logger.Warn(err)

View File

@@ -25,11 +25,11 @@ type server struct {
handler http.Handler handler http.Handler
} }
func New(address string, logEnabled bool, logger logging.Logger, func New(ctx context.Context, address string, logEnabled bool, logger logging.Logger,
buildInfo models.BuildInformation, buildInfo models.BuildInformation,
openvpnLooper openvpn.Looper, unboundLooper dns.Looper, openvpnLooper openvpn.Looper, unboundLooper dns.Looper,
updaterLooper updater.Looper, publicIPLooper publicip.Looper) Server { updaterLooper updater.Looper, publicIPLooper publicip.Looper) Server {
handler := newHandler(logger, logEnabled, buildInfo, handler := newHandler(ctx, logger, logEnabled, buildInfo,
openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
return &server{ return &server{
address: address, address: address,