diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index f9a570f5..498e9a1c 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -431,7 +431,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, "http server", goroutine.OptionTimeout(defaultShutdownTimeout)) httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging, logger.New(log.SetComponent("http server")), - buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper) + buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper, storage) if err != nil { return fmt.Errorf("cannot setup control server: %w", err) } diff --git a/internal/configuration/settings/settings.go b/internal/configuration/settings/settings.go index 74fb6df2..b9042553 100644 --- a/internal/configuration/settings/settings.go +++ b/internal/configuration/settings/settings.go @@ -46,7 +46,7 @@ func (s *Settings) Validate(storage Storage) (err error) { "version": s.Version.validate, // Pprof validation done in pprof constructor "VPN": func() error { - return s.VPN.validate(storage) + return s.VPN.Validate(storage) }, } @@ -73,7 +73,7 @@ func (s *Settings) copy() (copied Settings) { System: s.System.copy(), Updater: s.Updater.copy(), Version: s.Version.copy(), - VPN: s.VPN.copy(), + VPN: s.VPN.Copy(), Pprof: s.Pprof.Copy(), } } @@ -108,7 +108,7 @@ func (s *Settings) OverrideWith(other Settings, patchedSettings.System.overrideWith(other.System) patchedSettings.Updater.overrideWith(other.Updater) patchedSettings.Version.overrideWith(other.Version) - patchedSettings.VPN.overrideWith(other.VPN) + patchedSettings.VPN.OverrideWith(other.VPN) patchedSettings.Pprof.OverrideWith(other.Pprof) err = patchedSettings.Validate(storage) if err != nil { diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index 5923c5a6..8977638e 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -20,7 +20,7 @@ type VPN struct { } // TODO v4 remove pointer for receiver (because of Surfshark). -func (v *VPN) validate(storage Storage) (err error) { +func (v *VPN) Validate(storage Storage) (err error) { // Validate Type validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} if !helpers.IsOneOf(v.Type, validVPNTypes...) { @@ -48,7 +48,7 @@ func (v *VPN) validate(storage Storage) (err error) { return nil } -func (v *VPN) copy() (copied VPN) { +func (v *VPN) Copy() (copied VPN) { return VPN{ Type: v.Type, Provider: v.Provider.copy(), @@ -64,7 +64,7 @@ func (v *VPN) mergeWith(other VPN) { v.Wireguard.mergeWith(other.Wireguard) } -func (v *VPN) overrideWith(other VPN) { +func (v *VPN) OverrideWith(other VPN) { v.Type = helpers.OverrideWithString(v.Type, other.Type) v.Provider.overrideWith(other.Provider) v.OpenVPN.overrideWith(other.OpenVPN) diff --git a/internal/server/handler.go b/internal/server/handler.go index 443088c0..00173d5b 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -15,10 +15,11 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool, unboundLooper DNSLoop, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, + storage Storage, ) http.Handler { handler := &handler{} - vpn := newVPNHandler(ctx, vpnLooper, logger) + vpn := newVPNHandler(ctx, vpnLooper, storage, logger) openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger) dns := newDNSHandler(ctx, unboundLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger) diff --git a/internal/server/interfaces.go b/internal/server/interfaces.go index 5c2cec3a..62c24acb 100644 --- a/internal/server/interfaces.go +++ b/internal/server/interfaces.go @@ -12,6 +12,7 @@ type VPNLooper interface { ApplyStatus(ctx context.Context, status models.LoopStatus) ( outcome string, err error) GetSettings() (settings settings.VPN) + SetSettings(ctx context.Context, settings settings.VPN) (outcome string) } type DNSLoop interface { @@ -27,3 +28,7 @@ type PortForwardedGetter interface { type PublicIPLoop interface { GetData() (data models.PublicIP) } + +type Storage interface { + GetFilterChoices(provider string) models.FilterChoices +} diff --git a/internal/server/server.go b/internal/server/server.go index ea1b7599..e7e431d3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -11,9 +11,10 @@ import ( func New(ctx context.Context, address string, logEnabled bool, logger Logger, buildInfo models.BuildInformation, openvpnLooper VPNLooper, pfGetter PortForwardedGetter, unboundLooper DNSLoop, - updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop) (server *httpserver.Server, err error) { + updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage) ( + server *httpserver.Server, err error) { handler := newHandler(ctx, logger, logEnabled, buildInfo, - openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper) + openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper, storage) httpServerSettings := httpserver.Settings{ Address: address, diff --git a/internal/server/vpn.go b/internal/server/vpn.go index 9c452c45..ff91a5ae 100644 --- a/internal/server/vpn.go +++ b/internal/server/vpn.go @@ -5,21 +5,25 @@ import ( "encoding/json" "net/http" "strings" + + "github.com/qdm12/gluetun/internal/configuration/settings" ) func newVPNHandler(ctx context.Context, looper VPNLooper, - w warner) http.Handler { + storage Storage, w warner) http.Handler { return &vpnHandler{ - ctx: ctx, - looper: looper, - warner: w, + ctx: ctx, + looper: looper, + storage: storage, + warner: w, } } type vpnHandler struct { - ctx context.Context //nolint:containedctx - looper VPNLooper - warner warner + ctx context.Context //nolint:containedctx + looper VPNLooper + storage Storage + warner warner } func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -38,6 +42,8 @@ func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: h.getSettings(w) + case http.MethodPut: + h.patchSettings(w, r) default: http.Error(w, "method "+r.Method+" not supported", http.StatusBadRequest) } @@ -91,3 +97,32 @@ func (h *vpnHandler) getSettings(w http.ResponseWriter) { return } } + +func (h *vpnHandler) patchSettings(w http.ResponseWriter, r *http.Request) { + var overrideSettings settings.VPN + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&overrideSettings) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + err = r.Body.Close() + if err != nil { + h.warner.Warn("closing body: " + err.Error()) + } + + updatedSettings := h.looper.GetSettings() // already copied + updatedSettings.OverrideWith(overrideSettings) + err = updatedSettings.Validate(h.storage) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + outcome := h.looper.SetSettings(h.ctx, updatedSettings) + _, err = w.Write([]byte(outcome)) + if err != nil { + h.warner.Warn("writing response: " + err.Error()) + } +} diff --git a/internal/vpn/state/vpn.go b/internal/vpn/state/vpn.go index 34d75298..4ea67fd7 100644 --- a/internal/vpn/state/vpn.go +++ b/internal/vpn/state/vpn.go @@ -10,7 +10,7 @@ import ( func (s *State) GetSettings() (vpn settings.VPN) { s.settingsMu.RLock() - vpn = s.vpn + vpn = s.vpn.Copy() s.settingsMu.RUnlock() return vpn }