Maint: make all set status context aware
This commit is contained in:
@@ -384,7 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
|
|
||||||
// Start openvpn for the first time in a blocking call
|
// Start openvpn for the first time in a blocking call
|
||||||
// until openvpn is launched
|
// until openvpn is launched
|
||||||
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
_, _ = openvpnLooper.SetStatus(ctx, constants.Running) // TODO option to disable with variable
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
|
||||||
@@ -462,7 +462,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
|
|||||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
// Runs the Public IP getter job once
|
// Runs the Public IP getter job once
|
||||||
_, _ = publicIPLooper.SetStatus(constants.Running)
|
_, _ = publicIPLooper.SetStatus(ctx, constants.Running)
|
||||||
if versionInformation && first {
|
if versionInformation && first {
|
||||||
first = false
|
first = false
|
||||||
message, err := versionpkg.GetMessage(ctx, buildInfo, httpClient)
|
message, err := versionpkg.GetMessage(ctx, buildInfo, httpClient)
|
||||||
|
|||||||
@@ -69,9 +69,15 @@ func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
|
|||||||
l.state.status = constants.Stopping
|
l.state.status = constants.Stopping
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.stop <- struct{}{}
|
l.stop <- struct{}{}
|
||||||
<-l.stopped
|
|
||||||
|
newStatus := constants.Stopping // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-l.stopped:
|
||||||
|
newStatus = constants.Stopped
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = constants.Stopped
|
l.state.status = newStatus
|
||||||
return status.String(), nil
|
return status.String(), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||||
|
|||||||
@@ -15,10 +15,12 @@ import (
|
|||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, done chan<- struct{})
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error)
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
GetSettings() (settings configuration.HTTPProxy)
|
GetSettings() (settings configuration.HTTPProxy)
|
||||||
SetSettings(settings configuration.HTTPProxy) (outcome string)
|
SetSettings(ctx context.Context, settings configuration.HTTPProxy) (
|
||||||
|
outcome string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -57,7 +59,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
|
|
||||||
if l.GetSettings().Enabled {
|
if l.GetSettings().Enabled {
|
||||||
go func() {
|
go func() {
|
||||||
_, _ = l.SetStatus(constants.Running)
|
_, _ = l.SetStatus(ctx, constants.Running)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package httpproxy
|
package httpproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -32,7 +33,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
|
|||||||
|
|
||||||
var ErrInvalidStatus = errors.New("invalid status")
|
var ErrInvalidStatus = errors.New("invalid status")
|
||||||
|
|
||||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error) {
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
defer l.state.statusMu.Unlock()
|
defer l.state.statusMu.Unlock()
|
||||||
existingStatus := l.state.status
|
existingStatus := l.state.status
|
||||||
@@ -48,7 +50,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Starting
|
l.state.status = constants.Starting
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.start <- struct{}{}
|
l.start <- struct{}{}
|
||||||
newStatus := <-l.running
|
|
||||||
|
newStatus := constants.Starting // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case newStatus = <-l.running:
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = newStatus
|
l.state.status = newStatus
|
||||||
return newStatus.String(), nil
|
return newStatus.String(), nil
|
||||||
@@ -62,9 +69,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Stopping
|
l.state.status = constants.Stopping
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.stop <- struct{}{}
|
l.stop <- struct{}{}
|
||||||
<-l.stopped
|
|
||||||
|
newStatus := constants.Stopping // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-l.stopped:
|
||||||
|
newStatus = constants.Stopped
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = status
|
l.state.status = newStatus
|
||||||
return status.String(), nil
|
return status.String(), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||||
@@ -78,7 +91,8 @@ func (l *looper) GetSettings() (settings configuration.HTTPProxy) {
|
|||||||
return l.state.settings
|
return l.state.settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) SetSettings(settings configuration.HTTPProxy) (outcome string) {
|
func (l *looper) SetSettings(ctx context.Context, settings configuration.HTTPProxy) (
|
||||||
|
outcome string) {
|
||||||
l.state.settingsMu.Lock()
|
l.state.settingsMu.Lock()
|
||||||
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
||||||
if settingsUnchanged {
|
if settingsUnchanged {
|
||||||
@@ -93,12 +107,12 @@ func (l *looper) SetSettings(settings configuration.HTTPProxy) (outcome string)
|
|||||||
switch {
|
switch {
|
||||||
case !newEnabled && !previousEnabled:
|
case !newEnabled && !previousEnabled:
|
||||||
case newEnabled && previousEnabled:
|
case newEnabled && previousEnabled:
|
||||||
_, _ = l.SetStatus(constants.Stopped)
|
_, _ = l.SetStatus(ctx, constants.Stopped)
|
||||||
_, _ = l.SetStatus(constants.Running)
|
_, _ = l.SetStatus(ctx, constants.Running)
|
||||||
case newEnabled && !previousEnabled:
|
case newEnabled && !previousEnabled:
|
||||||
_, _ = l.SetStatus(constants.Running)
|
_, _ = l.SetStatus(ctx, constants.Running)
|
||||||
case !newEnabled && previousEnabled:
|
case !newEnabled && previousEnabled:
|
||||||
_, _ = l.SetStatus(constants.Stopped)
|
_, _ = l.SetStatus(ctx, constants.Stopped)
|
||||||
}
|
}
|
||||||
return "settings updated"
|
return "settings updated"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,9 +21,11 @@ import (
|
|||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, done chan<- struct{})
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error)
|
||||||
GetSettings() (settings configuration.OpenVPN)
|
GetSettings() (settings configuration.OpenVPN)
|
||||||
SetSettings(settings configuration.OpenVPN) (outcome string)
|
SetSettings(ctx context.Context, settings configuration.OpenVPN) (
|
||||||
|
outcome string)
|
||||||
GetServers() (servers models.AllServers)
|
GetServers() (servers models.AllServers)
|
||||||
SetServers(servers models.AllServers)
|
SetServers(servers models.AllServers)
|
||||||
GetPortForwarded() (port uint16)
|
GetPortForwarded() (port uint16)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package openvpn
|
package openvpn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -46,7 +47,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
|
|||||||
|
|
||||||
var ErrInvalidStatus = errors.New("invalid status")
|
var ErrInvalidStatus = errors.New("invalid status")
|
||||||
|
|
||||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error) {
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
defer l.state.statusMu.Unlock()
|
defer l.state.statusMu.Unlock()
|
||||||
existingStatus := l.state.status
|
existingStatus := l.state.status
|
||||||
@@ -62,7 +64,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Starting
|
l.state.status = constants.Starting
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.start <- struct{}{}
|
l.start <- struct{}{}
|
||||||
newStatus := <-l.running
|
|
||||||
|
newStatus := constants.Starting // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case newStatus = <-l.running:
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = newStatus
|
l.state.status = newStatus
|
||||||
return newStatus.String(), nil
|
return newStatus.String(), nil
|
||||||
@@ -76,9 +83,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Stopping
|
l.state.status = constants.Stopping
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.stop <- struct{}{}
|
l.stop <- struct{}{}
|
||||||
<-l.stopped
|
|
||||||
|
newStatus := constants.Stopping // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-l.stopped:
|
||||||
|
newStatus = constants.Stopped
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = constants.Stopped
|
l.state.status = newStatus
|
||||||
return status.String(), nil
|
return status.String(), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||||
@@ -92,7 +105,8 @@ func (l *looper) GetSettings() (settings configuration.OpenVPN) {
|
|||||||
return l.state.settings
|
return l.state.settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) SetSettings(settings configuration.OpenVPN) (outcome string) {
|
func (l *looper) SetSettings(ctx context.Context, settings configuration.OpenVPN) (
|
||||||
|
outcome string) {
|
||||||
l.state.settingsMu.Lock()
|
l.state.settingsMu.Lock()
|
||||||
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
|
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
|
||||||
if settingsUnchanged {
|
if settingsUnchanged {
|
||||||
@@ -100,8 +114,8 @@ func (l *looper) SetSettings(settings configuration.OpenVPN) (outcome string) {
|
|||||||
return "settings left unchanged"
|
return "settings left unchanged"
|
||||||
}
|
}
|
||||||
l.state.settings = settings
|
l.state.settings = settings
|
||||||
_, _ = l.SetStatus(constants.Stopped)
|
_, _ = l.SetStatus(ctx, constants.Stopped)
|
||||||
outcome, _ = l.SetStatus(constants.Running)
|
outcome, _ = l.SetStatus(ctx, constants.Running)
|
||||||
return outcome
|
return outcome
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ type Looper interface {
|
|||||||
Run(ctx context.Context, done chan<- struct{})
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error)
|
||||||
GetSettings() (settings configuration.PublicIP)
|
GetSettings() (settings configuration.PublicIP)
|
||||||
SetSettings(settings configuration.PublicIP) (outcome string)
|
SetSettings(settings configuration.PublicIP) (outcome string)
|
||||||
GetPublicIP() (publicIP net.IP)
|
GetPublicIP() (publicIP net.IP)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package publicip
|
package publicip
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -35,7 +36,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
|
|||||||
|
|
||||||
var ErrInvalidStatus = errors.New("invalid status")
|
var ErrInvalidStatus = errors.New("invalid status")
|
||||||
|
|
||||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error) {
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
defer l.state.statusMu.Unlock()
|
defer l.state.statusMu.Unlock()
|
||||||
existingStatus := l.state.status
|
existingStatus := l.state.status
|
||||||
@@ -51,7 +53,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Starting
|
l.state.status = constants.Starting
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.start <- struct{}{}
|
l.start <- struct{}{}
|
||||||
newStatus := <-l.running
|
|
||||||
|
newStatus := constants.Starting // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case newStatus = <-l.running:
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = newStatus
|
l.state.status = newStatus
|
||||||
return newStatus.String(), nil
|
return newStatus.String(), nil
|
||||||
@@ -65,9 +72,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Stopping
|
l.state.status = constants.Stopping
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.stop <- struct{}{}
|
l.stop <- struct{}{}
|
||||||
<-l.stopped
|
|
||||||
|
newStatus := constants.Stopping // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-l.stopped:
|
||||||
|
newStatus = constants.Stopped
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = status
|
l.state.status = newStatus
|
||||||
return status.String(), nil
|
return status.String(), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||||
@@ -81,7 +94,8 @@ func (l *looper) GetSettings() (settings configuration.PublicIP) {
|
|||||||
return l.state.settings
|
return l.state.settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) SetSettings(settings configuration.PublicIP) (outcome string) {
|
func (l *looper) SetSettings(settings configuration.PublicIP) (
|
||||||
|
outcome string) {
|
||||||
l.state.settingsMu.Lock()
|
l.state.settingsMu.Lock()
|
||||||
defer l.state.settingsMu.Unlock()
|
defer l.state.settingsMu.Unlock()
|
||||||
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ func newHandler(ctx context.Context, logger logging.Logger, logging bool,
|
|||||||
) http.Handler {
|
) http.Handler {
|
||||||
handler := &handler{}
|
handler := &handler{}
|
||||||
|
|
||||||
openvpn := newOpenvpnHandler(openvpnLooper, logger)
|
openvpn := newOpenvpnHandler(ctx, openvpnLooper, logger)
|
||||||
dns := newDNSHandler(ctx, unboundLooper, logger)
|
dns := newDNSHandler(ctx, unboundLooper, logger)
|
||||||
updater := newUpdaterHandler(updaterLooper, logger)
|
updater := newUpdaterHandler(ctx, updaterLooper, logger)
|
||||||
publicip := newPublicIPHandler(publicIPLooper, logger)
|
publicip := newPublicIPHandler(publicIPLooper, logger)
|
||||||
|
|
||||||
handler.v0 = newHandlerV0(ctx, logger, openvpnLooper, unboundLooper, updaterLooper)
|
handler.v0 = newHandlerV0(ctx, logger, openvpnLooper, unboundLooper, updaterLooper)
|
||||||
|
|||||||
@@ -39,9 +39,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
case "/version":
|
case "/version":
|
||||||
http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect)
|
http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect)
|
||||||
case "/openvpn/actions/restart":
|
case "/openvpn/actions/restart":
|
||||||
outcome, _ := h.openvpn.SetStatus(constants.Stopped)
|
outcome, _ := h.openvpn.SetStatus(h.ctx, constants.Stopped)
|
||||||
h.logger.Info("openvpn: %s", outcome)
|
h.logger.Info("openvpn: %s", outcome)
|
||||||
outcome, _ = h.openvpn.SetStatus(constants.Running)
|
outcome, _ = h.openvpn.SetStatus(h.ctx, constants.Running)
|
||||||
h.logger.Info("openvpn: %s", outcome)
|
h.logger.Info("openvpn: %s", outcome)
|
||||||
if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil {
|
if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||||
h.logger.Warn(err)
|
h.logger.Warn(err)
|
||||||
@@ -59,9 +59,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
case "/openvpn/settings":
|
case "/openvpn/settings":
|
||||||
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
|
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
|
||||||
case "/updater/restart":
|
case "/updater/restart":
|
||||||
outcome, _ := h.updater.SetStatus(constants.Stopped)
|
outcome, _ := h.updater.SetStatus(h.ctx, constants.Stopped)
|
||||||
h.logger.Info("updater: %s", outcome)
|
h.logger.Info("updater: %s", outcome)
|
||||||
outcome, _ = h.updater.SetStatus(constants.Running)
|
outcome, _ = h.updater.SetStatus(h.ctx, constants.Running)
|
||||||
h.logger.Info("updater: %s", outcome)
|
h.logger.Info("updater: %s", outcome)
|
||||||
if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil {
|
if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||||
h.logger.Warn(err)
|
h.logger.Warn(err)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -9,14 +10,17 @@ import (
|
|||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newOpenvpnHandler(looper openvpn.Looper, logger logging.Logger) http.Handler {
|
func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper,
|
||||||
|
logger logging.Logger) http.Handler {
|
||||||
return &openvpnHandler{
|
return &openvpnHandler{
|
||||||
|
ctx: ctx,
|
||||||
looper: looper,
|
looper: looper,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type openvpnHandler struct {
|
type openvpnHandler struct {
|
||||||
|
ctx context.Context
|
||||||
looper openvpn.Looper
|
looper openvpn.Looper
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
}
|
}
|
||||||
@@ -75,7 +79,7 @@ func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
outcome, err := h.looper.SetStatus(status)
|
outcome, err := h.looper.SetStatus(h.ctx, status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -10,15 +11,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func newUpdaterHandler(
|
func newUpdaterHandler(
|
||||||
|
ctx context.Context,
|
||||||
looper updater.Looper,
|
looper updater.Looper,
|
||||||
logger logging.Logger) http.Handler {
|
logger logging.Logger) http.Handler {
|
||||||
return &updaterHandler{
|
return &updaterHandler{
|
||||||
|
ctx: ctx,
|
||||||
looper: looper,
|
looper: looper,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type updaterHandler struct {
|
type updaterHandler struct {
|
||||||
|
ctx context.Context
|
||||||
looper updater.Looper
|
looper updater.Looper
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
}
|
}
|
||||||
@@ -63,7 +67,7 @@ func (h *updaterHandler) setStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
outcome, err := h.looper.SetStatus(status)
|
outcome, err := h.looper.SetStatus(h.ctx, status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ import (
|
|||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, done chan<- struct{})
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error)
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
GetSettings() (settings configuration.ShadowSocks)
|
GetSettings() (settings configuration.ShadowSocks)
|
||||||
SetSettings(settings configuration.ShadowSocks) (outcome string)
|
SetSettings(ctx context.Context, settings configuration.ShadowSocks) (
|
||||||
|
outcome string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -74,7 +76,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
|
|
||||||
if l.GetSettings().Enabled {
|
if l.GetSettings().Enabled {
|
||||||
go func() {
|
go func() {
|
||||||
_, _ = l.SetStatus(constants.Running)
|
_, _ = l.SetStatus(ctx, constants.Running)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package shadowsocks
|
package shadowsocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -32,7 +33,8 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
|
|||||||
|
|
||||||
var ErrInvalidStatus = errors.New("invalid status")
|
var ErrInvalidStatus = errors.New("invalid status")
|
||||||
|
|
||||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error) {
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
defer l.state.statusMu.Unlock()
|
defer l.state.statusMu.Unlock()
|
||||||
existingStatus := l.state.status
|
existingStatus := l.state.status
|
||||||
@@ -48,7 +50,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Starting
|
l.state.status = constants.Starting
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.start <- struct{}{}
|
l.start <- struct{}{}
|
||||||
newStatus := <-l.running
|
|
||||||
|
newStatus := constants.Starting // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case newStatus = <-l.running:
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = newStatus
|
l.state.status = newStatus
|
||||||
return newStatus.String(), nil
|
return newStatus.String(), nil
|
||||||
@@ -62,9 +69,14 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Stopping
|
l.state.status = constants.Stopping
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.stop <- struct{}{}
|
l.stop <- struct{}{}
|
||||||
<-l.stopped
|
newStatus := constants.Stopping // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-l.stopped:
|
||||||
|
newStatus = constants.Stopped
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = status
|
l.state.status = newStatus
|
||||||
return status.String(), nil
|
return status.String(), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||||
@@ -78,7 +90,8 @@ func (l *looper) GetSettings() (settings configuration.ShadowSocks) {
|
|||||||
return l.state.settings
|
return l.state.settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) SetSettings(settings configuration.ShadowSocks) (outcome string) {
|
func (l *looper) SetSettings(ctx context.Context, settings configuration.ShadowSocks) (
|
||||||
|
outcome string) {
|
||||||
l.state.settingsMu.Lock()
|
l.state.settingsMu.Lock()
|
||||||
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
||||||
if settingsUnchanged {
|
if settingsUnchanged {
|
||||||
@@ -93,12 +106,12 @@ func (l *looper) SetSettings(settings configuration.ShadowSocks) (outcome string
|
|||||||
switch {
|
switch {
|
||||||
case !newEnabled && !previousEnabled:
|
case !newEnabled && !previousEnabled:
|
||||||
case newEnabled && previousEnabled:
|
case newEnabled && previousEnabled:
|
||||||
_, _ = l.SetStatus(constants.Stopped)
|
_, _ = l.SetStatus(ctx, constants.Stopped)
|
||||||
_, _ = l.SetStatus(constants.Running)
|
_, _ = l.SetStatus(ctx, constants.Running)
|
||||||
case newEnabled && !previousEnabled:
|
case newEnabled && !previousEnabled:
|
||||||
_, _ = l.SetStatus(constants.Running)
|
_, _ = l.SetStatus(ctx, constants.Running)
|
||||||
case !newEnabled && previousEnabled:
|
case !newEnabled && previousEnabled:
|
||||||
_, _ = l.SetStatus(constants.Stopped)
|
_, _ = l.SetStatus(ctx, constants.Stopped)
|
||||||
}
|
}
|
||||||
return "settings updated"
|
return "settings updated"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ type Looper interface {
|
|||||||
Run(ctx context.Context, done chan<- struct{})
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(ctx context.Context, status models.LoopStatus) (
|
||||||
|
outcome string, err error)
|
||||||
GetSettings() (settings configuration.Updater)
|
GetSettings() (settings configuration.Updater)
|
||||||
SetSettings(settings configuration.Updater) (outcome string)
|
SetSettings(settings configuration.Updater) (outcome string)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package updater
|
package updater
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -32,7 +33,7 @@ func (l *looper) GetStatus() (status models.LoopStatus) {
|
|||||||
|
|
||||||
var ErrInvalidStatus = errors.New("invalid status")
|
var ErrInvalidStatus = errors.New("invalid status")
|
||||||
|
|
||||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) (outcome string, err error) {
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
defer l.state.statusMu.Unlock()
|
defer l.state.statusMu.Unlock()
|
||||||
existingStatus := l.state.status
|
existingStatus := l.state.status
|
||||||
@@ -48,7 +49,12 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Starting
|
l.state.status = constants.Starting
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.start <- struct{}{}
|
l.start <- struct{}{}
|
||||||
newStatus := <-l.running
|
|
||||||
|
newStatus := constants.Starting // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case newStatus = <-l.running:
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = newStatus
|
l.state.status = newStatus
|
||||||
return newStatus.String(), nil
|
return newStatus.String(), nil
|
||||||
@@ -62,9 +68,15 @@ func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error)
|
|||||||
l.state.status = constants.Stopping
|
l.state.status = constants.Stopping
|
||||||
l.state.statusMu.Unlock()
|
l.state.statusMu.Unlock()
|
||||||
l.stop <- struct{}{}
|
l.stop <- struct{}{}
|
||||||
<-l.stopped
|
|
||||||
|
newStatus := constants.Stopping // for canceled context
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-l.stopped:
|
||||||
|
newStatus = constants.Stopped
|
||||||
|
}
|
||||||
l.state.statusMu.Lock()
|
l.state.statusMu.Lock()
|
||||||
l.state.status = status
|
l.state.status = newStatus
|
||||||
return status.String(), nil
|
return status.String(), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
|
||||||
|
|||||||
Reference in New Issue
Block a user