Maint: make all set status context aware

This commit is contained in:
Quentin McGaw (desktop)
2021-07-16 00:49:59 +00:00
parent 6bbb7c8f7d
commit 0ed738cd61
16 changed files with 146 additions and 57 deletions

View File

@@ -384,7 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
// Start openvpn for the first time in a blocking call // Start openvpn for the first time in a blocking call
// until openvpn is launched // until openvpn is launched
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable _, _ = openvpnLooper.SetStatus(ctx, constants.Running) // TODO option to disable with variable
<-ctx.Done() <-ctx.Done()
@@ -462,7 +462,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
restartTickerContext, restartTickerCancel = context.WithCancel(ctx) restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
// Runs the Public IP getter job once // Runs the Public IP getter job once
_, _ = publicIPLooper.SetStatus(constants.Running) _, _ = publicIPLooper.SetStatus(ctx, constants.Running)
if versionInformation && first { if versionInformation && first {
first = false first = false
message, err := versionpkg.GetMessage(ctx, buildInfo, httpClient) message, err := versionpkg.GetMessage(ctx, buildInfo, httpClient)

View File

@@ -69,9 +69,15 @@ func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
l.state.status = constants.Stopping l.state.status = constants.Stopping
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.stop <- struct{}{} l.stop <- struct{}{}
<-l.stopped
newStatus := constants.Stopping // for canceled context
select {
case <-ctx.Done():
case <-l.stopped:
newStatus = constants.Stopped
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = constants.Stopped l.state.status = newStatus
return status.String(), nil return status.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",

View File

@@ -15,10 +15,12 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
GetSettings() (settings configuration.HTTPProxy) GetSettings() (settings configuration.HTTPProxy)
SetSettings(settings configuration.HTTPProxy) (outcome string) SetSettings(ctx context.Context, settings configuration.HTTPProxy) (
outcome string)
} }
type looper struct { type looper struct {
@@ -57,7 +59,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
if l.GetSettings().Enabled { if l.GetSettings().Enabled {
go func() { go func() {
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
}() }()
} }

View File

@@ -1,6 +1,7 @@
package httpproxy package httpproxy
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -32,7 +33,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
var ErrInvalidStatus = errors.New("invalid status") var ErrInvalidStatus = errors.New("invalid status")
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error) {
l.state.statusMu.Lock() l.state.statusMu.Lock()
defer l.state.statusMu.Unlock() defer l.state.statusMu.Unlock()
existingStatus := l.state.status existingStatus := l.state.status
@@ -48,7 +50,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Starting l.state.status = constants.Starting
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.start <- struct{}{} l.start <- struct{}{}
newStatus := <-l.running
newStatus := constants.Starting // for canceled context
select {
case <-ctx.Done():
case newStatus = <-l.running:
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = newStatus l.state.status = newStatus
return newStatus.String(), nil return newStatus.String(), nil
@@ -62,9 +69,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Stopping l.state.status = constants.Stopping
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.stop <- struct{}{} l.stop <- struct{}{}
<-l.stopped
newStatus := constants.Stopping // for canceled context
select {
case <-ctx.Done():
case <-l.stopped:
newStatus = constants.Stopped
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = status l.state.status = newStatus
return status.String(), nil return status.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
@@ -78,7 +91,8 @@ func (l *looper) GetSettings() (settings configuration.HTTPProxy) {
return l.state.settings return l.state.settings
} }
func (l *looper) SetSettings(settings configuration.HTTPProxy) (outcome string) { func (l *looper) SetSettings(ctx context.Context, settings configuration.HTTPProxy) (
outcome string) {
l.state.settingsMu.Lock() l.state.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings) settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
if settingsUnchanged { if settingsUnchanged {
@@ -93,12 +107,12 @@ func (l *looper) SetSettings(settings configuration.HTTPProxy) (outcome string)
switch { switch {
case !newEnabled && !previousEnabled: case !newEnabled && !previousEnabled:
case newEnabled && previousEnabled: case newEnabled && previousEnabled:
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
case newEnabled && !previousEnabled: case newEnabled && !previousEnabled:
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
case !newEnabled && previousEnabled: case !newEnabled && previousEnabled:
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
} }
return "settings updated" return "settings updated"
} }

View File

@@ -21,9 +21,11 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
GetSettings() (settings configuration.OpenVPN) GetSettings() (settings configuration.OpenVPN)
SetSettings(settings configuration.OpenVPN) (outcome string) SetSettings(ctx context.Context, settings configuration.OpenVPN) (
outcome string)
GetServers() (servers models.AllServers) GetServers() (servers models.AllServers)
SetServers(servers models.AllServers) SetServers(servers models.AllServers)
GetPortForwarded() (port uint16) GetPortForwarded() (port uint16)

View File

@@ -1,6 +1,7 @@
package openvpn package openvpn
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -46,7 +47,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
var ErrInvalidStatus = errors.New("invalid status") var ErrInvalidStatus = errors.New("invalid status")
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error) {
l.state.statusMu.Lock() l.state.statusMu.Lock()
defer l.state.statusMu.Unlock() defer l.state.statusMu.Unlock()
existingStatus := l.state.status existingStatus := l.state.status
@@ -62,7 +64,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Starting l.state.status = constants.Starting
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.start <- struct{}{} l.start <- struct{}{}
newStatus := <-l.running
newStatus := constants.Starting // for canceled context
select {
case <-ctx.Done():
case newStatus = <-l.running:
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = newStatus l.state.status = newStatus
return newStatus.String(), nil return newStatus.String(), nil
@@ -76,9 +83,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Stopping l.state.status = constants.Stopping
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.stop <- struct{}{} l.stop <- struct{}{}
<-l.stopped
newStatus := constants.Stopping // for canceled context
select {
case <-ctx.Done():
case <-l.stopped:
newStatus = constants.Stopped
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = constants.Stopped l.state.status = newStatus
return status.String(), nil return status.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
@@ -92,7 +105,8 @@ func (l *looper) GetSettings() (settings configuration.OpenVPN) {
return l.state.settings return l.state.settings
} }
func (l *looper) SetSettings(settings configuration.OpenVPN) (outcome string) { func (l *looper) SetSettings(ctx context.Context, settings configuration.OpenVPN) (
outcome string) {
l.state.settingsMu.Lock() l.state.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings) settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
if settingsUnchanged { if settingsUnchanged {
@@ -100,8 +114,8 @@ func (l *looper) SetSettings(settings configuration.OpenVPN) (outcome string) {
return "settings left unchanged" return "settings left unchanged"
} }
l.state.settings = settings l.state.settings = settings
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
outcome, _ = l.SetStatus(constants.Running) outcome, _ = l.SetStatus(ctx, constants.Running)
return outcome return outcome
} }

View File

@@ -18,7 +18,8 @@ type Looper interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
RunRestartTicker(ctx context.Context, done chan<- struct{}) RunRestartTicker(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
GetSettings() (settings configuration.PublicIP) GetSettings() (settings configuration.PublicIP)
SetSettings(settings configuration.PublicIP) (outcome string) SetSettings(settings configuration.PublicIP) (outcome string)
GetPublicIP() (publicIP net.IP) GetPublicIP() (publicIP net.IP)

View File

@@ -1,6 +1,7 @@
package publicip package publicip
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -35,7 +36,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
var ErrInvalidStatus = errors.New("invalid status") var ErrInvalidStatus = errors.New("invalid status")
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error) {
l.state.statusMu.Lock() l.state.statusMu.Lock()
defer l.state.statusMu.Unlock() defer l.state.statusMu.Unlock()
existingStatus := l.state.status existingStatus := l.state.status
@@ -51,7 +53,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Starting l.state.status = constants.Starting
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.start <- struct{}{} l.start <- struct{}{}
newStatus := <-l.running
newStatus := constants.Starting // for canceled context
select {
case <-ctx.Done():
case newStatus = <-l.running:
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = newStatus l.state.status = newStatus
return newStatus.String(), nil return newStatus.String(), nil
@@ -65,9 +72,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Stopping l.state.status = constants.Stopping
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.stop <- struct{}{} l.stop <- struct{}{}
<-l.stopped
newStatus := constants.Stopping // for canceled context
select {
case <-ctx.Done():
case <-l.stopped:
newStatus = constants.Stopped
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = status l.state.status = newStatus
return status.String(), nil return status.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
@@ -81,7 +94,8 @@ func (l *looper) GetSettings() (settings configuration.PublicIP) {
return l.state.settings return l.state.settings
} }
func (l *looper) SetSettings(settings configuration.PublicIP) (outcome string) { func (l *looper) SetSettings(settings configuration.PublicIP) (
outcome string) {
l.state.settingsMu.Lock() l.state.settingsMu.Lock()
defer l.state.settingsMu.Unlock() defer l.state.settingsMu.Unlock()
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings) settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)

View File

@@ -22,9 +22,9 @@ func newHandler(ctx context.Context, logger logging.Logger, logging bool,
) http.Handler { ) http.Handler {
handler := &handler{} handler := &handler{}
openvpn := newOpenvpnHandler(openvpnLooper, logger) openvpn := newOpenvpnHandler(ctx, openvpnLooper, logger)
dns := newDNSHandler(ctx, unboundLooper, logger) dns := newDNSHandler(ctx, unboundLooper, logger)
updater := newUpdaterHandler(updaterLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger)
handler.v0 = newHandlerV0(ctx, logger, openvpnLooper, unboundLooper, updaterLooper) handler.v0 = newHandlerV0(ctx, logger, openvpnLooper, unboundLooper, updaterLooper)

View File

@@ -39,9 +39,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "/version": case "/version":
http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect) http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect)
case "/openvpn/actions/restart": case "/openvpn/actions/restart":
outcome, _ := h.openvpn.SetStatus(constants.Stopped) outcome, _ := h.openvpn.SetStatus(h.ctx, constants.Stopped)
h.logger.Info("openvpn: %s", outcome) h.logger.Info("openvpn: %s", outcome)
outcome, _ = h.openvpn.SetStatus(constants.Running) outcome, _ = h.openvpn.SetStatus(h.ctx, constants.Running)
h.logger.Info("openvpn: %s", outcome) h.logger.Info("openvpn: %s", outcome)
if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil { if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil {
h.logger.Warn(err) h.logger.Warn(err)
@@ -59,9 +59,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "/openvpn/settings": case "/openvpn/settings":
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect) http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
case "/updater/restart": case "/updater/restart":
outcome, _ := h.updater.SetStatus(constants.Stopped) outcome, _ := h.updater.SetStatus(h.ctx, constants.Stopped)
h.logger.Info("updater: %s", outcome) h.logger.Info("updater: %s", outcome)
outcome, _ = h.updater.SetStatus(constants.Running) outcome, _ = h.updater.SetStatus(h.ctx, constants.Running)
h.logger.Info("updater: %s", outcome) h.logger.Info("updater: %s", outcome)
if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil { if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil {
h.logger.Warn(err) h.logger.Warn(err)

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@@ -9,14 +10,17 @@ import (
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func newOpenvpnHandler(looper openvpn.Looper, logger logging.Logger) http.Handler { func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper,
logger logging.Logger) http.Handler {
return &openvpnHandler{ return &openvpnHandler{
ctx: ctx,
looper: looper, looper: looper,
logger: logger, logger: logger,
} }
} }
type openvpnHandler struct { type openvpnHandler struct {
ctx context.Context
looper openvpn.Looper looper openvpn.Looper
logger logging.Logger logger logging.Logger
} }
@@ -75,7 +79,7 @@ func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
outcome, err := h.looper.SetStatus(status) outcome, err := h.looper.SetStatus(h.ctx, status)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@@ -10,15 +11,18 @@ import (
) )
func newUpdaterHandler( func newUpdaterHandler(
ctx context.Context,
looper updater.Looper, looper updater.Looper,
logger logging.Logger) http.Handler { logger logging.Logger) http.Handler {
return &updaterHandler{ return &updaterHandler{
ctx: ctx,
looper: looper, looper: looper,
logger: logger, logger: logger,
} }
} }
type updaterHandler struct { type updaterHandler struct {
ctx context.Context
looper updater.Looper looper updater.Looper
logger logging.Logger logger logging.Logger
} }
@@ -63,7 +67,7 @@ func (h *updaterHandler) setStatus(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
outcome, err := h.looper.SetStatus(status) outcome, err := h.looper.SetStatus(h.ctx, status)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return

View File

@@ -16,10 +16,12 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
GetSettings() (settings configuration.ShadowSocks) GetSettings() (settings configuration.ShadowSocks)
SetSettings(settings configuration.ShadowSocks) (outcome string) SetSettings(ctx context.Context, settings configuration.ShadowSocks) (
outcome string)
} }
type looper struct { type looper struct {
@@ -74,7 +76,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
if l.GetSettings().Enabled { if l.GetSettings().Enabled {
go func() { go func() {
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
}() }()
} }

View File

@@ -1,6 +1,7 @@
package shadowsocks package shadowsocks
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -32,7 +33,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
var ErrInvalidStatus = errors.New("invalid status") var ErrInvalidStatus = errors.New("invalid status")
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error) {
l.state.statusMu.Lock() l.state.statusMu.Lock()
defer l.state.statusMu.Unlock() defer l.state.statusMu.Unlock()
existingStatus := l.state.status existingStatus := l.state.status
@@ -48,7 +50,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Starting l.state.status = constants.Starting
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.start <- struct{}{} l.start <- struct{}{}
newStatus := <-l.running
newStatus := constants.Starting // for canceled context
select {
case <-ctx.Done():
case newStatus = <-l.running:
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = newStatus l.state.status = newStatus
return newStatus.String(), nil return newStatus.String(), nil
@@ -62,9 +69,14 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Stopping l.state.status = constants.Stopping
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.stop <- struct{}{} l.stop <- struct{}{}
<-l.stopped newStatus := constants.Stopping // for canceled context
select {
case <-ctx.Done():
case <-l.stopped:
newStatus = constants.Stopped
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = status l.state.status = newStatus
return status.String(), nil return status.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
@@ -78,7 +90,8 @@ func (l *looper) GetSettings() (settings configuration.ShadowSocks) {
return l.state.settings return l.state.settings
} }
func (l *looper) SetSettings(settings configuration.ShadowSocks) (outcome string) { func (l *looper) SetSettings(ctx context.Context, settings configuration.ShadowSocks) (
outcome string) {
l.state.settingsMu.Lock() l.state.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings) settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
if settingsUnchanged { if settingsUnchanged {
@@ -93,12 +106,12 @@ func (l *looper) SetSettings(settings configuration.ShadowSocks) (outcome string
switch { switch {
case !newEnabled && !previousEnabled: case !newEnabled && !previousEnabled:
case newEnabled && previousEnabled: case newEnabled && previousEnabled:
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
case newEnabled && !previousEnabled: case newEnabled && !previousEnabled:
_, _ = l.SetStatus(constants.Running) _, _ = l.SetStatus(ctx, constants.Running)
case !newEnabled && previousEnabled: case !newEnabled && previousEnabled:
_, _ = l.SetStatus(constants.Stopped) _, _ = l.SetStatus(ctx, constants.Stopped)
} }
return "settings updated" return "settings updated"
} }

View File

@@ -17,7 +17,8 @@ type Looper interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
RunRestartTicker(ctx context.Context, done chan<- struct{}) RunRestartTicker(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(ctx context.Context, status models.LoopStatus) (
outcome string, err error)
GetSettings() (settings configuration.Updater) GetSettings() (settings configuration.Updater)
SetSettings(settings configuration.Updater) (outcome string) SetSettings(settings configuration.Updater) (outcome string)
} }

View File

@@ -1,6 +1,7 @@
package updater package updater
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -32,7 +33,7 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
var ErrInvalidStatus = errors.New("invalid status") var ErrInvalidStatus = errors.New("invalid status")
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (outcome string, err error) {
l.state.statusMu.Lock() l.state.statusMu.Lock()
defer l.state.statusMu.Unlock() defer l.state.statusMu.Unlock()
existingStatus := l.state.status existingStatus := l.state.status
@@ -48,7 +49,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Starting l.state.status = constants.Starting
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.start <- struct{}{} l.start <- struct{}{}
newStatus := <-l.running
newStatus := constants.Starting // for canceled context
select {
case <-ctx.Done():
case newStatus = <-l.running:
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = newStatus l.state.status = newStatus
return newStatus.String(), nil return newStatus.String(), nil
@@ -62,9 +68,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
l.state.status = constants.Stopping l.state.status = constants.Stopping
l.state.statusMu.Unlock() l.state.statusMu.Unlock()
l.stop <- struct{}{} l.stop <- struct{}{}
<-l.stopped
newStatus := constants.Stopping // for canceled context
select {
case <-ctx.Done():
case <-l.stopped:
newStatus = constants.Stopped
}
l.state.statusMu.Lock() l.state.statusMu.Lock()
l.state.status = status l.state.status = newStatus
return status.String(), nil return status.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",