Loops and HTTP control server rework (#308)
- CRUD REST HTTP server - `/v1` HTTP server prefix - Retrocompatible with older routes (redirects to v1 or handles the requests directly) - DNS, Updater and Openvpn refactored to have a REST-like state with new methods to change their states synchronously - Openvpn, Unbound and Updater status, see #287
This commit is contained in:
@@ -217,15 +217,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, uid, gid, allServers,
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
|
||||
wg.Add(1)
|
||||
// wait for restartOpenvpn
|
||||
go openvpnLooper.Run(ctx, wg)
|
||||
|
||||
updaterOptions := updater.NewOptions("127.0.0.1")
|
||||
updaterLooper := updater.NewLooper(updaterOptions, allSettings.UpdaterPeriod,
|
||||
allServers, storage, openvpnLooper.SetAllServers, httpClient, logger)
|
||||
updaterLooper := updater.NewLooper(allSettings.Updater,
|
||||
allServers, storage, openvpnLooper.SetServers, httpClient, logger)
|
||||
wg.Add(1)
|
||||
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
|
||||
go updaterLooper.Run(ctx, wg)
|
||||
@@ -276,8 +275,9 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
wg.Add(1)
|
||||
go healthcheckServer.Run(ctx, wg)
|
||||
|
||||
// Start openvpn for the first time
|
||||
openvpnLooper.Restart()
|
||||
// Start openvpn for the first time in a blocking call
|
||||
// until openvpn is launched
|
||||
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
||||
|
||||
signalsCh := make(chan os.Signal, 1)
|
||||
signal.Notify(signalsCh,
|
||||
@@ -401,7 +401,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
|
||||
tickerWg.Wait()
|
||||
return
|
||||
case <-tunnelReadyCh: // blocks until openvpn is connected
|
||||
unboundLooper.Restart()
|
||||
_, _ = unboundLooper.SetStatus(constants.Running)
|
||||
restartTickerCancel() // stop previous restart tickers
|
||||
tickerWg.Wait()
|
||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||
|
||||
@@ -83,7 +83,7 @@ func OpenvpnConfig() error {
|
||||
}
|
||||
|
||||
func Update(args []string) error {
|
||||
options := updater.Options{CLI: true}
|
||||
options := settings.Updater{CLI: true}
|
||||
var flushToFile bool
|
||||
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
||||
flagSet.BoolVar(&flushToFile, "file", false, "Write results to /gluetun/servers.json (for end users)")
|
||||
|
||||
14
internal/constants/status.go
Normal file
14
internal/constants/status.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
Starting models.LoopStatus = "starting"
|
||||
Running models.LoopStatus = "running"
|
||||
Stopping models.LoopStatus = "stopping"
|
||||
Stopped models.LoopStatus = "stopped"
|
||||
Crashed models.LoopStatus = "crashed"
|
||||
Completed models.LoopStatus = "completed"
|
||||
)
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
@@ -15,24 +16,24 @@ import (
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func())
|
||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||
Restart()
|
||||
Start()
|
||||
Stop()
|
||||
GetStatus() (status models.LoopStatus)
|
||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||
GetSettings() (settings settings.DNS)
|
||||
SetSettings(settings settings.DNS)
|
||||
SetSettings(settings settings.DNS) (outcome string)
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
state state
|
||||
conf Configurator
|
||||
settings settings.DNS
|
||||
settingsMutex sync.RWMutex
|
||||
logger logging.Logger
|
||||
streamMerger command.StreamMerger
|
||||
uid int
|
||||
gid int
|
||||
restart chan struct{}
|
||||
loopLock sync.Mutex
|
||||
start chan struct{}
|
||||
running chan models.LoopStatus
|
||||
stop chan struct{}
|
||||
stopped chan struct{}
|
||||
updateTicker chan struct{}
|
||||
timeNow func() time.Time
|
||||
timeSince func(time.Time) time.Duration
|
||||
@@ -41,54 +42,25 @@ type looper struct {
|
||||
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
||||
streamMerger command.StreamMerger, uid, gid int) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
settings: settings,
|
||||
},
|
||||
conf: conf,
|
||||
logger: logger.WithPrefix("dns over tls: "),
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
streamMerger: streamMerger,
|
||||
restart: make(chan struct{}),
|
||||
start: make(chan struct{}),
|
||||
running: make(chan models.LoopStatus),
|
||||
stop: make(chan struct{}),
|
||||
stopped: make(chan struct{}),
|
||||
updateTicker: make(chan struct{}),
|
||||
timeNow: time.Now,
|
||||
timeSince: time.Since,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||
func (l *looper) Start() { l.start <- struct{}{} }
|
||||
func (l *looper) Stop() { l.stop <- struct{}{} }
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.DNS) {
|
||||
l.settingsMutex.RLock()
|
||||
defer l.settingsMutex.RUnlock()
|
||||
return l.settings
|
||||
}
|
||||
|
||||
func (l *looper) SetSettings(settings settings.DNS) {
|
||||
l.settingsMutex.Lock()
|
||||
defer l.settingsMutex.Unlock()
|
||||
updatePeriodDiffers := l.settings.UpdatePeriod != settings.UpdatePeriod
|
||||
l.settings = settings
|
||||
l.settingsMutex.Unlock()
|
||||
if updatePeriodDiffers {
|
||||
l.updateTicker <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) isEnabled() bool {
|
||||
l.settingsMutex.RLock()
|
||||
defer l.settingsMutex.RUnlock()
|
||||
return l.settings.Enabled
|
||||
}
|
||||
|
||||
func (l *looper) setEnabled(enabled bool) {
|
||||
l.settingsMutex.Lock()
|
||||
defer l.settingsMutex.Unlock()
|
||||
l.settings.Enabled = enabled
|
||||
}
|
||||
|
||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
l.logger.Warn(err)
|
||||
l.logger.Info("attempting restart in 10 seconds")
|
||||
@@ -103,96 +75,42 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) {
|
||||
for {
|
||||
select {
|
||||
case <-l.stop:
|
||||
l.setEnabled(false)
|
||||
l.logger.Info("not started yet")
|
||||
case <-l.restart:
|
||||
if l.isEnabled() {
|
||||
return
|
||||
}
|
||||
signalDNSReady()
|
||||
l.logger.Info("not restarting because disabled")
|
||||
case <-l.start:
|
||||
l.setEnabled(true)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel context.CancelFunc) {
|
||||
if l.isEnabled() {
|
||||
return
|
||||
}
|
||||
for {
|
||||
// wait for a signal to re-enable
|
||||
select {
|
||||
case <-l.stop:
|
||||
l.logger.Info("already disabled")
|
||||
case <-l.restart:
|
||||
if !l.isEnabled() {
|
||||
l.logger.Info("not restarting because disabled")
|
||||
} else {
|
||||
return
|
||||
}
|
||||
case <-l.start:
|
||||
l.setEnabled(true)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
unboundCancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) {
|
||||
defer wg.Done()
|
||||
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.waitForFirstStart(ctx, signalDNSReady)
|
||||
if ctx.Err() != nil {
|
||||
l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution?
|
||||
|
||||
select {
|
||||
case <-l.start:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
var unboundCtx context.Context
|
||||
var unboundCancel context.CancelFunc = func() {}
|
||||
var waitError chan error
|
||||
triggeredRestart := false
|
||||
l.setEnabled(true)
|
||||
for ctx.Err() == nil {
|
||||
l.waitForSubsequentStart(ctx, unboundCancel)
|
||||
err := l.updateFiles(ctx)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
l.state.setStatusWithLock(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
}
|
||||
|
||||
crashed := false
|
||||
|
||||
for ctx.Err() == nil {
|
||||
settings := l.GetSettings()
|
||||
|
||||
// Setup
|
||||
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if triggeredRestart {
|
||||
triggeredRestart = false
|
||||
unboundCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
}
|
||||
unboundCtx, unboundCancel = context.WithCancel(context.Background())
|
||||
unboundCtx, unboundCancel := context.WithCancel(context.Background())
|
||||
stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel)
|
||||
if err != nil {
|
||||
unboundCancel()
|
||||
if !crashed {
|
||||
l.running <- constants.Crashed
|
||||
}
|
||||
crashed = true
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.logAndWait(ctx, err)
|
||||
@@ -201,23 +119,37 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
|
||||
|
||||
// Started successfully
|
||||
go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound"))
|
||||
|
||||
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
|
||||
l.logger.Error(err)
|
||||
}
|
||||
|
||||
if err := l.conf.WaitForUnbound(); err != nil {
|
||||
if !crashed {
|
||||
l.running <- constants.Crashed
|
||||
crashed = true
|
||||
}
|
||||
unboundCancel()
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
waitError = make(chan error)
|
||||
|
||||
waitError := make(chan error)
|
||||
go func() {
|
||||
err := waitFn() // blocking
|
||||
waitError <- err
|
||||
}()
|
||||
|
||||
l.logger.Info("DNS over TLS is ready")
|
||||
if !crashed {
|
||||
l.running <- constants.Running
|
||||
crashed = false
|
||||
} else {
|
||||
l.state.setStatusWithLock(constants.Running)
|
||||
}
|
||||
signalDNSReady()
|
||||
|
||||
stayHere := true
|
||||
@@ -229,32 +161,29 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
|
||||
<-waitError
|
||||
close(waitError)
|
||||
return
|
||||
case <-l.restart: // triggered restart
|
||||
l.logger.Info("restarting")
|
||||
// unboundCancel occurs next loop run when the setup is complete
|
||||
triggeredRestart = true
|
||||
stayHere = false
|
||||
case <-l.start:
|
||||
l.logger.Info("already started")
|
||||
case <-l.stop:
|
||||
l.logger.Info("stopping")
|
||||
const fallback = false
|
||||
l.useUnencryptedDNS(fallback)
|
||||
unboundCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
l.setEnabled(false)
|
||||
l.stopped <- struct{}{}
|
||||
case <-l.start:
|
||||
l.logger.Info("starting")
|
||||
stayHere = false
|
||||
case err := <-waitError: // unexpected error
|
||||
close(waitError)
|
||||
unboundCancel()
|
||||
l.state.setStatusWithLock(constants.Crashed)
|
||||
const fallback = true
|
||||
l.useUnencryptedDNS(fallback)
|
||||
l.logAndWait(ctx, err)
|
||||
stayHere = false
|
||||
}
|
||||
}
|
||||
}
|
||||
close(waitError)
|
||||
unboundCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||
settings := l.GetSettings()
|
||||
@@ -279,7 +208,11 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||
data := constants.DNSProviderMapping()[provider]
|
||||
for _, targetIP = range data.IPs {
|
||||
if targetIP.To4() != nil {
|
||||
if fallback {
|
||||
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
|
||||
} else {
|
||||
l.logger.Info("using plaintext DNS at address %s", targetIP)
|
||||
}
|
||||
l.conf.UseDNSInternally(targetIP)
|
||||
if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil {
|
||||
l.logger.Error(err)
|
||||
@@ -314,7 +247,20 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
return
|
||||
case <-timer.C:
|
||||
lastTick = l.timeNow()
|
||||
l.restart <- struct{}{}
|
||||
|
||||
status := l.GetStatus()
|
||||
if status == constants.Running {
|
||||
if err := l.updateFiles(ctx); err != nil {
|
||||
l.state.setStatusWithLock(constants.Crashed)
|
||||
l.logger.Error(err)
|
||||
l.logger.Warn("skipping Unbound restart due to failed files update")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = l.SetStatus(constants.Stopped)
|
||||
_, _ = l.SetStatus(constants.Running)
|
||||
|
||||
settings := l.GetSettings()
|
||||
timer.Reset(settings.UpdatePeriod)
|
||||
case <-l.updateTicker:
|
||||
@@ -337,3 +283,17 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) updateFiles(ctx context.Context) (err error) {
|
||||
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
|
||||
return err
|
||||
}
|
||||
settings := l.GetSettings()
|
||||
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
96
internal/dns/state.go
Normal file
96
internal/dns/state.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
)
|
||||
|
||||
type state struct {
|
||||
status models.LoopStatus
|
||||
settings settings.DNS
|
||||
statusMu sync.RWMutex
|
||||
settingsMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *state) setStatusWithLock(status models.LoopStatus) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.status = status
|
||||
}
|
||||
|
||||
func (l *looper) GetStatus() (status models.LoopStatus) {
|
||||
l.state.statusMu.RLock()
|
||||
defer l.state.statusMu.RUnlock()
|
||||
return l.state.status
|
||||
}
|
||||
|
||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
||||
l.state.statusMu.Lock()
|
||||
defer l.state.statusMu.Unlock()
|
||||
existingStatus := l.state.status
|
||||
|
||||
switch status {
|
||||
case constants.Running:
|
||||
switch existingStatus {
|
||||
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
|
||||
return fmt.Sprintf("already %s", existingStatus), nil
|
||||
}
|
||||
l.loopLock.Lock()
|
||||
defer l.loopLock.Unlock()
|
||||
l.state.status = constants.Starting
|
||||
l.state.statusMu.Unlock()
|
||||
l.start <- struct{}{}
|
||||
newStatus := <-l.running
|
||||
l.state.statusMu.Lock()
|
||||
l.state.status = newStatus
|
||||
return newStatus.String(), nil
|
||||
case constants.Stopped:
|
||||
switch existingStatus {
|
||||
case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
|
||||
return fmt.Sprintf("already %s", existingStatus), nil
|
||||
}
|
||||
l.loopLock.Lock()
|
||||
defer l.loopLock.Unlock()
|
||||
l.state.status = constants.Stopping
|
||||
l.state.statusMu.Unlock()
|
||||
l.stop <- struct{}{}
|
||||
<-l.stopped
|
||||
l.state.statusMu.Lock()
|
||||
l.state.status = constants.Stopped
|
||||
return status.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("status %q can only be %q or %q",
|
||||
status, constants.Running, constants.Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.DNS) {
|
||||
l.state.settingsMu.RLock()
|
||||
defer l.state.settingsMu.RUnlock()
|
||||
return l.state.settings
|
||||
}
|
||||
|
||||
func (l *looper) SetSettings(settings settings.DNS) (outcome string) {
|
||||
l.state.settingsMu.Lock()
|
||||
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
|
||||
if settingsUnchanged {
|
||||
l.state.settingsMu.Unlock()
|
||||
return "settings left unchanged"
|
||||
}
|
||||
tempSettings := l.state.settings
|
||||
tempSettings.UpdatePeriod = settings.UpdatePeriod
|
||||
onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings)
|
||||
l.state.settings = settings
|
||||
if onlyUpdatePeriodChanged {
|
||||
l.updateTicker <- struct{}{}
|
||||
return "update period changed"
|
||||
}
|
||||
_, _ = l.SetStatus(constants.Stopped)
|
||||
outcome, _ = l.SetStatus(constants.Running)
|
||||
return outcome
|
||||
}
|
||||
@@ -20,8 +20,14 @@ type (
|
||||
VPNProvider string
|
||||
// NetworkProtocol contains the network protocol to be used to communicate with the VPN servers.
|
||||
NetworkProtocol string
|
||||
// Loop status such as stopped or running.
|
||||
LoopStatus string
|
||||
)
|
||||
|
||||
func (ls LoopStatus) String() string {
|
||||
return string(ls)
|
||||
}
|
||||
|
||||
func marshalJSONString(s string) (data []byte, err error) {
|
||||
return []byte(fmt.Sprintf("%q", s)), nil
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package models
|
||||
type BuildInformation struct {
|
||||
Version string `json:"version"`
|
||||
Commit string `json:"commit"`
|
||||
BuildDate string `json:"buildDate"`
|
||||
BuildDate string `json:"build_date"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package models
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type OpenVPNConnection struct {
|
||||
IP net.IP
|
||||
|
||||
@@ -9,15 +9,15 @@ import (
|
||||
// ProviderSettings contains settings specific to a VPN provider.
|
||||
type ProviderSettings struct {
|
||||
Name VPNProvider `json:"name"`
|
||||
ServerSelection ServerSelection `json:"serverSelection"`
|
||||
ExtraConfigOptions ExtraConfigOptions `json:"extraConfig"`
|
||||
PortForwarding PortForwarding `json:"portForwarding"`
|
||||
ServerSelection ServerSelection `json:"server_selection"`
|
||||
ExtraConfigOptions ExtraConfigOptions `json:"extra_config"`
|
||||
PortForwarding PortForwarding `json:"port_forwarding"`
|
||||
}
|
||||
|
||||
type ServerSelection struct {
|
||||
// Common
|
||||
Protocol NetworkProtocol `json:"networkProtocol"`
|
||||
TargetIP net.IP `json:"targetIP,omitempty"`
|
||||
Protocol NetworkProtocol `json:"network_protocol"`
|
||||
TargetIP net.IP `json:"target_ip,omitempty"`
|
||||
|
||||
// Cyberghost, PIA, Surfshark, Windscribe, Vyprvpn, NordVPN
|
||||
Regions []string `json:"regions"`
|
||||
@@ -34,20 +34,20 @@ type ServerSelection struct {
|
||||
Owned bool `json:"owned"`
|
||||
|
||||
// Mullvad, Windscribe
|
||||
CustomPort uint16 `json:"customPort"`
|
||||
CustomPort uint16 `json:"custom_port"`
|
||||
|
||||
// NordVPN
|
||||
Numbers []uint16 `json:"numbers"`
|
||||
|
||||
// PIA
|
||||
EncryptionPreset string `json:"encryptionPreset"`
|
||||
EncryptionPreset string `json:"encryption_preset"`
|
||||
}
|
||||
|
||||
type ExtraConfigOptions struct {
|
||||
ClientCertificate string `json:"-"` // Cyberghost
|
||||
ClientKey string `json:"-"` // Cyberghost
|
||||
EncryptionPreset string `json:"encryptionPreset"` // PIA
|
||||
OpenVPNIPv6 bool `json:"openvpnIPv6"` // Mullvad
|
||||
EncryptionPreset string `json:"encryption_preset"` // PIA
|
||||
OpenVPNIPv6 bool `json:"openvpn_ipv6"` // Mullvad
|
||||
}
|
||||
|
||||
// PortForwarding contains settings for port forwarding.
|
||||
|
||||
@@ -20,23 +20,18 @@ import (
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||
Restart()
|
||||
PortForward(vpnGatewayIP net.IP)
|
||||
GetStatus() (status models.LoopStatus)
|
||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||
GetSettings() (settings settings.OpenVPN)
|
||||
SetSettings(settings settings.OpenVPN)
|
||||
GetPortForwarded() (portForwarded uint16)
|
||||
SetAllServers(allServers models.AllServers)
|
||||
SetSettings(settings settings.OpenVPN) (outcome string)
|
||||
GetServers() (servers models.AllServers)
|
||||
SetServers(servers models.AllServers)
|
||||
GetPortForwarded() (port uint16)
|
||||
PortForward(vpnGatewayIP net.IP)
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
// Variable parameters
|
||||
provider models.VPNProvider
|
||||
settings settings.OpenVPN
|
||||
settingsMutex sync.RWMutex
|
||||
portForwarded uint16
|
||||
portForwardedMutex sync.RWMutex
|
||||
allServers models.AllServers
|
||||
allServersMutex sync.RWMutex
|
||||
state state
|
||||
// Fixed parameters
|
||||
uid int
|
||||
gid int
|
||||
@@ -50,22 +45,27 @@ type looper struct {
|
||||
fileManager files.FileManager
|
||||
streamMerger command.StreamMerger
|
||||
cancel context.CancelFunc
|
||||
// Internal channels
|
||||
restart chan struct{}
|
||||
// Internal channels and locks
|
||||
loopLock sync.Mutex
|
||||
running chan models.LoopStatus
|
||||
stop, stopped chan struct{}
|
||||
start chan struct{}
|
||||
portForwardSignals chan net.IP
|
||||
}
|
||||
|
||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
func NewLooper(settings settings.OpenVPN,
|
||||
uid, gid int, allServers models.AllServers,
|
||||
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
||||
logger logging.Logger, client *http.Client, fileManager files.FileManager,
|
||||
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
|
||||
return &looper{
|
||||
provider: provider,
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
settings: settings,
|
||||
allServers: allServers,
|
||||
},
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
allServers: allServers,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
@@ -75,46 +75,29 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
fileManager: fileManager,
|
||||
streamMerger: streamMerger,
|
||||
cancel: cancel,
|
||||
restart: make(chan struct{}),
|
||||
start: make(chan struct{}),
|
||||
running: make(chan models.LoopStatus),
|
||||
stop: make(chan struct{}),
|
||||
stopped: make(chan struct{}),
|
||||
portForwardSignals: make(chan net.IP),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.OpenVPN) {
|
||||
l.settingsMutex.RLock()
|
||||
defer l.settingsMutex.RUnlock()
|
||||
return l.settings
|
||||
}
|
||||
|
||||
func (l *looper) SetSettings(settings settings.OpenVPN) {
|
||||
l.settingsMutex.Lock()
|
||||
defer l.settingsMutex.Unlock()
|
||||
l.settings = settings
|
||||
}
|
||||
|
||||
func (l *looper) SetAllServers(allServers models.AllServers) {
|
||||
l.allServersMutex.Lock()
|
||||
defer l.allServersMutex.Unlock()
|
||||
l.allServers = allServers
|
||||
}
|
||||
|
||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
crashed := false
|
||||
select {
|
||||
case <-l.restart:
|
||||
case <-l.start:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
for ctx.Err() == nil {
|
||||
settings := l.GetSettings()
|
||||
l.allServersMutex.RLock()
|
||||
providerConf := provider.New(l.provider, l.allServers, time.Now)
|
||||
l.allServersMutex.RUnlock()
|
||||
settings, allServers := l.state.getSettingsAndServers()
|
||||
providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
|
||||
connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
@@ -155,6 +138,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||
if err != nil {
|
||||
openvpnCancel()
|
||||
if !crashed {
|
||||
l.running <- constants.Crashed
|
||||
crashed = true
|
||||
}
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
@@ -179,6 +166,16 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
err := waitFn() // blocking
|
||||
waitError <- err
|
||||
}()
|
||||
|
||||
if !crashed {
|
||||
l.running <- constants.Running
|
||||
crashed = false
|
||||
} else {
|
||||
l.state.setStatusWithLock(constants.Running)
|
||||
}
|
||||
|
||||
stayHere := true
|
||||
for stayHere {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
l.logger.Warn("context canceled: exiting loop")
|
||||
@@ -186,17 +183,25 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
<-waitError
|
||||
close(waitError)
|
||||
return
|
||||
case <-l.restart: // triggered restart
|
||||
l.logger.Info("restarting")
|
||||
case <-l.stop:
|
||||
l.logger.Info("stopping")
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
l.stopped <- struct{}{}
|
||||
case <-l.start:
|
||||
l.logger.Info("starting")
|
||||
stayHere = false
|
||||
case err := <-waitError: // unexpected error
|
||||
openvpnCancel()
|
||||
close(waitError)
|
||||
l.state.setStatusWithLock(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
crashed = true
|
||||
stayHere = false
|
||||
}
|
||||
}
|
||||
close(waitError)
|
||||
openvpnCancel() // just for the linter
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
@@ -218,24 +223,21 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
||||
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
||||
defer wg.Done()
|
||||
settings := l.GetSettings()
|
||||
l.state.portForwardedMu.RLock()
|
||||
settings := l.state.settings
|
||||
l.state.portForwardedMu.RUnlock()
|
||||
if !settings.Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
syncState := func(port uint16) (pfFilepath models.Filepath) {
|
||||
l.portForwardedMutex.Lock()
|
||||
l.portForwarded = port
|
||||
l.portForwardedMutex.Unlock()
|
||||
settings := l.GetSettings()
|
||||
l.state.portForwardedMu.Lock()
|
||||
defer l.state.portForwardedMu.Unlock()
|
||||
l.state.portForwarded = port
|
||||
l.state.settingsMu.RLock()
|
||||
defer l.state.settingsMu.RUnlock()
|
||||
return settings.Provider.PortForwarding.Filepath
|
||||
}
|
||||
providerConf.PortForward(ctx,
|
||||
client, l.fileManager, l.pfLogger,
|
||||
gateway, l.fw, syncState)
|
||||
}
|
||||
|
||||
func (l *looper) GetPortForwarded() (portForwarded uint16) {
|
||||
l.portForwardedMutex.RLock()
|
||||
defer l.portForwardedMutex.RUnlock()
|
||||
return l.portForwarded
|
||||
}
|
||||
|
||||
121
internal/openvpn/state.go
Normal file
121
internal/openvpn/state.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
)
|
||||
|
||||
type state struct {
|
||||
status models.LoopStatus
|
||||
settings settings.OpenVPN
|
||||
allServers models.AllServers
|
||||
portForwarded uint16
|
||||
statusMu sync.RWMutex
|
||||
settingsMu sync.RWMutex
|
||||
allServersMu sync.RWMutex
|
||||
portForwardedMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *state) setStatusWithLock(status models.LoopStatus) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.status = status
|
||||
}
|
||||
|
||||
func (s *state) getSettingsAndServers() (settings settings.OpenVPN, allServers models.AllServers) {
|
||||
s.settingsMu.RLock()
|
||||
s.allServersMu.RLock()
|
||||
settings = s.settings
|
||||
allServers = s.allServers
|
||||
s.settingsMu.RLock()
|
||||
s.allServersMu.RLock()
|
||||
return settings, allServers
|
||||
}
|
||||
|
||||
func (l *looper) GetStatus() (status models.LoopStatus) {
|
||||
l.state.statusMu.RLock()
|
||||
defer l.state.statusMu.RUnlock()
|
||||
return l.state.status
|
||||
}
|
||||
|
||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
||||
l.state.statusMu.Lock()
|
||||
defer l.state.statusMu.Unlock()
|
||||
existingStatus := l.state.status
|
||||
|
||||
switch status {
|
||||
case constants.Running:
|
||||
switch existingStatus {
|
||||
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
|
||||
return fmt.Sprintf("already %s", existingStatus), nil
|
||||
}
|
||||
l.loopLock.Lock()
|
||||
defer l.loopLock.Unlock()
|
||||
l.state.status = constants.Starting
|
||||
l.state.statusMu.Unlock()
|
||||
l.start <- struct{}{}
|
||||
newStatus := <-l.running
|
||||
l.state.statusMu.Lock()
|
||||
l.state.status = newStatus
|
||||
return newStatus.String(), nil
|
||||
case constants.Stopped:
|
||||
switch existingStatus {
|
||||
case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
|
||||
return fmt.Sprintf("already %s", existingStatus), nil
|
||||
}
|
||||
l.loopLock.Lock()
|
||||
defer l.loopLock.Unlock()
|
||||
l.state.status = constants.Stopping
|
||||
l.state.statusMu.Unlock()
|
||||
l.stop <- struct{}{}
|
||||
<-l.stopped
|
||||
l.state.statusMu.Lock()
|
||||
l.state.status = constants.Stopped
|
||||
return status.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("status %q can only be %q or %q",
|
||||
status, constants.Running, constants.Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.OpenVPN) {
|
||||
l.state.settingsMu.RLock()
|
||||
defer l.state.settingsMu.RUnlock()
|
||||
return l.state.settings
|
||||
}
|
||||
|
||||
func (l *looper) SetSettings(settings settings.OpenVPN) (outcome string) {
|
||||
l.state.settingsMu.Lock()
|
||||
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
|
||||
if settingsUnchanged {
|
||||
l.state.settingsMu.Unlock()
|
||||
return "settings left unchanged"
|
||||
}
|
||||
l.state.settings = settings
|
||||
_, _ = l.SetStatus(constants.Stopped)
|
||||
outcome, _ = l.SetStatus(constants.Running)
|
||||
return outcome
|
||||
}
|
||||
|
||||
func (l *looper) GetServers() (servers models.AllServers) {
|
||||
l.state.allServersMu.RLock()
|
||||
defer l.state.allServersMu.RUnlock()
|
||||
return l.state.allServers
|
||||
}
|
||||
|
||||
func (l *looper) SetServers(servers models.AllServers) {
|
||||
l.state.allServersMu.Lock()
|
||||
defer l.state.allServersMu.Unlock()
|
||||
l.state.allServers = servers
|
||||
}
|
||||
|
||||
func (l *looper) GetPortForwarded() (port uint16) {
|
||||
l.state.portForwardedMu.RLock()
|
||||
defer l.state.portForwardedMu.RUnlock()
|
||||
return port
|
||||
}
|
||||
@@ -130,8 +130,8 @@ func (r *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err
|
||||
return privateAddresses, nil
|
||||
}
|
||||
|
||||
// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using ipv6 DNS over TLS
|
||||
// servers from the environment variable DOT_IPV6.
|
||||
// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using
|
||||
// ipv6 DNS over TLS from the environment variable DOT_IPV6.
|
||||
func (r *reader) GetDNSOverTLSIPv6() (ipv6 bool, err error) {
|
||||
return r.envParams.GetOnOff("DOT_IPV6", libparams.Default("off"))
|
||||
}
|
||||
|
||||
76
internal/server/dns.go
Normal file
76
internal/server/dns.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//nolint:dupl
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/dns"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func newDNSHandler(looper dns.Looper, logger logging.Logger) http.Handler {
|
||||
return &dnsHandler{
|
||||
looper: looper,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type dnsHandler struct {
|
||||
looper dns.Looper
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (h *dnsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/dns")
|
||||
switch r.RequestURI {
|
||||
case "/status": //nolint:goconst
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getStatus(w)
|
||||
case http.MethodPut:
|
||||
h.setStatus(w, r)
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *dnsHandler) getStatus(w http.ResponseWriter) {
|
||||
status := h.looper.GetStatus()
|
||||
encoder := json.NewEncoder(w)
|
||||
data := statusWrapper{Status: string(status)}
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
var data statusWrapper
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
status, err := data.getStatus()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
outcome, err := h.looper.SetStatus(status)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
encoder := json.NewEncoder(w)
|
||||
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
|
||||
h.logger.Warn(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/dns"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -17,54 +17,33 @@ func newHandler(logger logging.Logger, logging bool,
|
||||
unboundLooper dns.Looper,
|
||||
updaterLooper updater.Looper,
|
||||
) http.Handler {
|
||||
return &handler{
|
||||
logger: logger,
|
||||
logging: logging,
|
||||
buildInfo: buildInfo,
|
||||
openvpnLooper: openvpnLooper,
|
||||
unboundLooper: unboundLooper,
|
||||
updaterLooper: updaterLooper,
|
||||
}
|
||||
handler := &handler{}
|
||||
|
||||
openvpn := newOpenvpnHandler(openvpnLooper, logger)
|
||||
dns := newDNSHandler(unboundLooper, logger)
|
||||
updater := newUpdaterHandler(updaterLooper, logger)
|
||||
|
||||
handler.v0 = newHandlerV0(logger, openvpnLooper, unboundLooper, updaterLooper)
|
||||
handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater)
|
||||
|
||||
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
||||
handler.setLogEnabled = handlerWithLog.setEnabled
|
||||
|
||||
return handlerWithLog
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
logger logging.Logger
|
||||
logging bool
|
||||
buildInfo models.BuildInformation
|
||||
openvpnLooper openvpn.Looper
|
||||
unboundLooper dns.Looper
|
||||
updaterLooper updater.Looper
|
||||
v0 http.Handler
|
||||
v1 http.Handler
|
||||
setLogEnabled func(enabled bool)
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
|
||||
if h.logging {
|
||||
h.logger.Info("HTTP %s %s", request.Method, request.RequestURI)
|
||||
}
|
||||
switch request.Method {
|
||||
case http.MethodGet:
|
||||
switch request.RequestURI {
|
||||
case "/version":
|
||||
h.getVersion(responseWriter)
|
||||
responseWriter.WriteHeader(http.StatusOK)
|
||||
case "/openvpn/actions/restart":
|
||||
h.openvpnLooper.Restart()
|
||||
responseWriter.WriteHeader(http.StatusOK)
|
||||
case "/unbound/actions/restart":
|
||||
h.unboundLooper.Restart()
|
||||
responseWriter.WriteHeader(http.StatusOK)
|
||||
case "/openvpn/portforwarded":
|
||||
h.getPortForwarded(responseWriter)
|
||||
case "/openvpn/settings":
|
||||
h.getOpenvpnSettings(responseWriter)
|
||||
case "/updater/restart":
|
||||
h.updaterLooper.Restart()
|
||||
responseWriter.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI)
|
||||
http.Error(responseWriter, errString, http.StatusBadRequest)
|
||||
}
|
||||
default:
|
||||
errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI)
|
||||
http.Error(responseWriter, errString, http.StatusBadRequest)
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RequestURI = strings.TrimSuffix(r.RequestURI, "/")
|
||||
if !strings.HasPrefix(r.RequestURI, "/v1/") && r.RequestURI != "/v1" {
|
||||
h.v0.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/v1")
|
||||
h.v1.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
69
internal/server/handlerv0.go
Normal file
69
internal/server/handlerv0.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/dns"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func newHandlerV0(logger logging.Logger,
|
||||
openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler {
|
||||
return &handlerV0{
|
||||
logger: logger,
|
||||
openvpn: openvpn,
|
||||
dns: dns,
|
||||
updater: updater,
|
||||
}
|
||||
}
|
||||
|
||||
type handlerV0 struct {
|
||||
logger logging.Logger
|
||||
openvpn openvpn.Looper
|
||||
dns dns.Looper
|
||||
updater updater.Looper
|
||||
}
|
||||
|
||||
func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "unversioned API: only supports GET method", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
switch r.RequestURI {
|
||||
case "/version":
|
||||
http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect)
|
||||
case "/openvpn/actions/restart":
|
||||
outcome, _ := h.openvpn.SetStatus(constants.Stopped)
|
||||
h.logger.Info("openvpn: %s", outcome)
|
||||
outcome, _ = h.openvpn.SetStatus(constants.Running)
|
||||
h.logger.Info("openvpn: %s", outcome)
|
||||
if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||
h.logger.Warn(err)
|
||||
}
|
||||
case "/unbound/actions/restart":
|
||||
outcome, _ := h.dns.SetStatus(constants.Stopped)
|
||||
h.logger.Info("dns: %s", outcome)
|
||||
outcome, _ = h.dns.SetStatus(constants.Running)
|
||||
h.logger.Info("dns: %s", outcome)
|
||||
if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||
h.logger.Warn(err)
|
||||
}
|
||||
case "/openvpn/portforwarded":
|
||||
http.Redirect(w, r, "/v1/openvpn/portforwarded", http.StatusPermanentRedirect)
|
||||
case "/openvpn/settings":
|
||||
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
|
||||
case "/updater/restart":
|
||||
outcome, _ := h.updater.SetStatus(constants.Stopped)
|
||||
h.logger.Info("updater: %s", outcome)
|
||||
outcome, _ = h.updater.SetStatus(constants.Running)
|
||||
h.logger.Info("updater: %s", outcome)
|
||||
if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||
h.logger.Warn(err)
|
||||
}
|
||||
default:
|
||||
http.Error(w, "unversioned API: requested URI not found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
58
internal/server/handlerv1.go
Normal file
58
internal/server/handlerv1.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func newHandlerV1(logger logging.Logger, buildInfo models.BuildInformation,
|
||||
openvpn, dns, updater http.Handler) http.Handler {
|
||||
return &handlerV1{
|
||||
logger: logger,
|
||||
buildInfo: buildInfo,
|
||||
openvpn: openvpn,
|
||||
dns: dns,
|
||||
updater: updater,
|
||||
}
|
||||
}
|
||||
|
||||
type handlerV1 struct {
|
||||
logger logging.Logger
|
||||
buildInfo models.BuildInformation
|
||||
openvpn http.Handler
|
||||
dns http.Handler
|
||||
updater http.Handler
|
||||
}
|
||||
|
||||
func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.RequestURI == "/version" && r.Method == http.MethodGet:
|
||||
h.getVersion(w)
|
||||
case strings.HasPrefix(r.RequestURI, "/openvpn"):
|
||||
h.openvpn.ServeHTTP(w, r)
|
||||
case strings.HasPrefix(r.RequestURI, "/dns"):
|
||||
h.dns.ServeHTTP(w, r)
|
||||
case strings.HasPrefix(r.RequestURI, "/updater"):
|
||||
h.updater.ServeHTTP(w, r)
|
||||
default:
|
||||
errString := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI)
|
||||
http.Error(w, errString, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handlerV1) getVersion(w http.ResponseWriter) {
|
||||
data, err := json.Marshal(h.buildInfo)
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
}
|
||||
}
|
||||
75
internal/server/log.go
Normal file
75
internal/server/log.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func withLogMiddleware(childHandler http.Handler, logger logging.Logger, enabled bool) *logMiddleware {
|
||||
return &logMiddleware{
|
||||
childHandler: childHandler,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
enabled: enabled,
|
||||
}
|
||||
}
|
||||
|
||||
type logMiddleware struct {
|
||||
childHandler http.Handler
|
||||
logger logging.Logger
|
||||
timeNow func() time.Time
|
||||
enabled bool
|
||||
enabledMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.isEnabled() {
|
||||
m.childHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
tStart := m.timeNow()
|
||||
statefulWriter := &statefulResponseWriter{httpWriter: w}
|
||||
m.childHandler.ServeHTTP(statefulWriter, r)
|
||||
duration := m.timeNow().Sub(tStart)
|
||||
m.logger.Info("%d %s %s wrote %dB to %s in %s",
|
||||
statefulWriter.statusCode, r.Method, r.RequestURI, statefulWriter.length, r.RemoteAddr, duration)
|
||||
}
|
||||
|
||||
func (m *logMiddleware) setEnabled(enabled bool) {
|
||||
m.enabledMu.Lock()
|
||||
defer m.enabledMu.Unlock()
|
||||
m.enabled = enabled
|
||||
}
|
||||
|
||||
func (m *logMiddleware) isEnabled() (enabled bool) {
|
||||
m.enabledMu.RLock()
|
||||
defer m.enabledMu.RUnlock()
|
||||
return m.enabled
|
||||
}
|
||||
|
||||
type statefulResponseWriter struct {
|
||||
httpWriter http.ResponseWriter
|
||||
statusCode int
|
||||
length int
|
||||
}
|
||||
|
||||
func (w *statefulResponseWriter) Write(b []byte) (n int, err error) {
|
||||
n, err = w.httpWriter.Write(b)
|
||||
if w.statusCode == 0 {
|
||||
w.statusCode = http.StatusOK
|
||||
}
|
||||
w.length += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *statefulResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.httpWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *statefulResponseWriter) Header() http.Header {
|
||||
return w.httpWriter.Header()
|
||||
}
|
||||
@@ -3,34 +3,110 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func (h *handler) getPortForwarded(w http.ResponseWriter) {
|
||||
port := h.openvpnLooper.GetPortForwarded()
|
||||
data, err := json.Marshal(struct {
|
||||
Port uint16 `json:"port"`
|
||||
}{port})
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
func newOpenvpnHandler(looper openvpn.Looper, logger logging.Logger) http.Handler {
|
||||
return &openvpnHandler{
|
||||
looper: looper,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) getOpenvpnSettings(w http.ResponseWriter) {
|
||||
settings := h.openvpnLooper.GetSettings()
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
type openvpnHandler struct {
|
||||
looper openvpn.Looper
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/openvpn")
|
||||
switch r.RequestURI {
|
||||
case "/status":
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getStatus(w)
|
||||
case http.MethodPut:
|
||||
h.setStatus(w, r)
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
case "/settings":
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getSettings(w)
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
case "/portforwarded":
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getPortForwarded(w)
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) getStatus(w http.ResponseWriter) {
|
||||
status := h.looper.GetStatus()
|
||||
encoder := json.NewEncoder(w)
|
||||
data := statusWrapper{Status: string(status)}
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) { //nolint:dupl
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
var data statusWrapper
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
status, err := data.getStatus()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
outcome, err := h.looper.SetStatus(status)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
encoder := json.NewEncoder(w)
|
||||
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
|
||||
h.logger.Warn(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
|
||||
settings := h.looper.GetSettings()
|
||||
settings.User = "redacted"
|
||||
settings.Password = "redacted"
|
||||
encoder := json.NewEncoder(w)
|
||||
if err := encoder.Encode(settings); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
|
||||
port := h.looper.GetPortForwarded()
|
||||
encoder := json.NewEncoder(w)
|
||||
data := portWrapper{Port: port}
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,8 @@ type server struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func New(address string, logging bool, logger logging.Logger, buildInfo models.BuildInformation,
|
||||
func New(address string, logging bool, logger logging.Logger,
|
||||
buildInfo models.BuildInformation,
|
||||
openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server {
|
||||
serverLogger := logger.WithPrefix("http server: ")
|
||||
handler := newHandler(serverLogger, logging, buildInfo, openvpnLooper, unboundLooper, updaterLooper)
|
||||
|
||||
78
internal/server/updater.go
Normal file
78
internal/server/updater.go
Normal file
@@ -0,0 +1,78 @@
|
||||
//nolint:dupl
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func newUpdaterHandler(
|
||||
looper updater.Looper,
|
||||
logger logging.Logger) http.Handler {
|
||||
return &updaterHandler{
|
||||
looper: looper,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type updaterHandler struct {
|
||||
looper updater.Looper
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (h *updaterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/updater")
|
||||
switch r.RequestURI {
|
||||
case "/status":
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getStatus(w)
|
||||
case http.MethodPut:
|
||||
h.setStatus(w, r)
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *updaterHandler) getStatus(w http.ResponseWriter) {
|
||||
status := h.looper.GetStatus()
|
||||
encoder := json.NewEncoder(w)
|
||||
data := statusWrapper{Status: string(status)}
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *updaterHandler) setStatus(w http.ResponseWriter, r *http.Request) {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
var data statusWrapper
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
status, err := data.getStatus()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
outcome, err := h.looper.SetStatus(status)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
encoder := json.NewEncoder(w)
|
||||
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
|
||||
h.logger.Warn(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (h *handler) getVersion(w http.ResponseWriter) {
|
||||
data, err := json.Marshal(h.buildInfo)
|
||||
if err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
h.logger.Warn(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
32
internal/server/wrappers.go
Normal file
32
internal/server/wrappers.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type statusWrapper struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) {
|
||||
status = models.LoopStatus(sw.Status)
|
||||
switch status {
|
||||
case constants.Stopped, constants.Running:
|
||||
return status, nil
|
||||
default:
|
||||
return "", fmt.Errorf(
|
||||
"invalid status %q: possible values are: %s, %s",
|
||||
sw.Status, constants.Stopped, constants.Running)
|
||||
}
|
||||
}
|
||||
|
||||
type portWrapper struct {
|
||||
Port uint16 `json:"port"`
|
||||
}
|
||||
|
||||
type outcomeWrapper struct {
|
||||
Outcome string `json:"outcome"`
|
||||
}
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
// OpenVPN contains settings to configure the OpenVPN client.
|
||||
type OpenVPN struct {
|
||||
User string `json:"user"`
|
||||
Password string `json:"-"`
|
||||
Password string `json:"password"`
|
||||
Verbosity int `json:"verbosity"`
|
||||
Root bool `json:"runAsRoot"`
|
||||
Root bool `json:"run_as_root"`
|
||||
Cipher string `json:"cipher"`
|
||||
Auth string `json:"auth"`
|
||||
Provider models.ProviderSettings `json:"provider"`
|
||||
|
||||
@@ -20,7 +20,7 @@ func Test_OpenVPN_JSON(t *testing.T) {
|
||||
data, err := json.Marshal(in)
|
||||
require.NoError(t, err)
|
||||
//nolint:lll
|
||||
assert.Equal(t, `{"user":"","verbosity":0,"runAsRoot":true,"cipher":"","auth":"","provider":{"name":"name","serverSelection":{"networkProtocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"customPort":0,"numbers":null,"encryptionPreset":""},"extraConfig":{"encryptionPreset":"","openvpnIPv6":false},"portForwarding":{"enabled":false,"filepath":""}}}`, string(data))
|
||||
assert.Equal(t, `{"user":"","password":"","verbosity":0,"run_as_root":true,"cipher":"","auth":"","provider":{"name":"name","server_selection":{"network_protocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"custom_port":0,"numbers":null,"encryption_preset":""},"extra_config":{"encryption_preset":"","openvpn_ipv6":false},"port_forwarding":{"enabled":false,"filepath":""}}}`, string(data))
|
||||
var out OpenVPN
|
||||
err = json.Unmarshal(data, &out)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -24,7 +23,7 @@ type Settings struct {
|
||||
HTTPProxy HTTPProxy
|
||||
ShadowSocks ShadowSocks
|
||||
PublicIPPeriod time.Duration
|
||||
UpdaterPeriod time.Duration
|
||||
Updater Updater
|
||||
VersionInformation bool
|
||||
ControlServer ControlServer
|
||||
}
|
||||
@@ -34,10 +33,6 @@ func (s *Settings) String() string {
|
||||
if s.VersionInformation {
|
||||
versionInformation = enabled
|
||||
}
|
||||
updaterLine := "Updater: disabled"
|
||||
if s.UpdaterPeriod > 0 {
|
||||
updaterLine = fmt.Sprintf("Updater period: %s", s.UpdaterPeriod)
|
||||
}
|
||||
return strings.Join([]string{
|
||||
"Settings summary below:",
|
||||
s.OpenVPN.String(),
|
||||
@@ -47,9 +42,9 @@ func (s *Settings) String() string {
|
||||
s.HTTPProxy.String(),
|
||||
s.ShadowSocks.String(),
|
||||
s.ControlServer.String(),
|
||||
s.Updater.String(),
|
||||
"Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0
|
||||
"Version information: " + versionInformation,
|
||||
updaterLine,
|
||||
"", // new line at the end
|
||||
}, "\n")
|
||||
}
|
||||
@@ -93,7 +88,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) {
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.UpdaterPeriod, err = paramsReader.GetUpdaterPeriod()
|
||||
settings.Updater, err = GetUpdaterSettings(paramsReader)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
|
||||
59
internal/settings/updater.go
Normal file
59
internal/settings/updater.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/params"
|
||||
)
|
||||
|
||||
type Updater struct {
|
||||
Period time.Duration `json:"period"`
|
||||
DNSAddress string `json:"dns_address"`
|
||||
Cyberghost bool `json:"cyberghost"`
|
||||
Mullvad bool `json:"mullvad"`
|
||||
Nordvpn bool `json:"nordvpn"`
|
||||
PIA bool `json:"pia"`
|
||||
Privado bool `json:"privado"`
|
||||
Purevpn bool `json:"purevpn"`
|
||||
Surfshark bool `json:"surfshark"`
|
||||
Vyprvpn bool `json:"vyprvpn"`
|
||||
Windscribe bool `json:"windscribe"`
|
||||
// The two below should be used in CLI mode only
|
||||
Stdout bool `json:"-"` // in order to update constants file (maintainer side)
|
||||
CLI bool `json:"-"`
|
||||
}
|
||||
|
||||
// GetUpdaterSettings obtains the server updater settings using the params functions.
|
||||
func GetUpdaterSettings(paramsReader params.Reader) (settings Updater, err error) {
|
||||
settings = Updater{
|
||||
Cyberghost: true,
|
||||
Mullvad: true,
|
||||
Nordvpn: true,
|
||||
PIA: true,
|
||||
Purevpn: true,
|
||||
Surfshark: true,
|
||||
Vyprvpn: true,
|
||||
Windscribe: true,
|
||||
Stdout: false,
|
||||
CLI: false,
|
||||
DNSAddress: "127.0.0.1",
|
||||
}
|
||||
settings.Period, err = paramsReader.GetUpdaterPeriod()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (s *Updater) String() string {
|
||||
if s.Period == 0 {
|
||||
return "Server updater settings: disabled"
|
||||
}
|
||||
settingsList := []string{
|
||||
"Server updater settings:",
|
||||
fmt.Sprintf("Period: %s", s.Period),
|
||||
}
|
||||
return strings.Join(settingsList, "\n|--")
|
||||
}
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
@@ -14,60 +16,54 @@ import (
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||
Restart()
|
||||
Stop()
|
||||
GetPeriod() (period time.Duration)
|
||||
SetPeriod(period time.Duration)
|
||||
GetStatus() (status models.LoopStatus)
|
||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||
GetSettings() (settings settings.Updater)
|
||||
SetSettings(settings settings.Updater) (outcome string)
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
period time.Duration
|
||||
periodMutex sync.RWMutex
|
||||
state state
|
||||
// Objects
|
||||
updater Updater
|
||||
storage storage.Storage
|
||||
setAllServers func(allServers models.AllServers)
|
||||
logger logging.Logger
|
||||
restart chan struct{}
|
||||
// Internal channels and locks
|
||||
loopLock sync.Mutex
|
||||
start chan struct{}
|
||||
running chan models.LoopStatus
|
||||
stop chan struct{}
|
||||
stopped chan struct{}
|
||||
updateTicker chan struct{}
|
||||
// Mock functions
|
||||
timeNow func() time.Time
|
||||
timeSince func(time.Time) time.Duration
|
||||
}
|
||||
|
||||
func NewLooper(options Options, period time.Duration, currentServers models.AllServers,
|
||||
func NewLooper(settings settings.Updater, currentServers models.AllServers,
|
||||
storage storage.Storage, setAllServers func(allServers models.AllServers),
|
||||
client *http.Client, logger logging.Logger) Looper {
|
||||
loggerWithPrefix := logger.WithPrefix("updater: ")
|
||||
return &looper{
|
||||
period: period,
|
||||
updater: New(options, client, currentServers, loggerWithPrefix),
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
settings: settings,
|
||||
},
|
||||
updater: New(settings, client, currentServers, loggerWithPrefix),
|
||||
storage: storage,
|
||||
setAllServers: setAllServers,
|
||||
logger: loggerWithPrefix,
|
||||
restart: make(chan struct{}),
|
||||
start: make(chan struct{}),
|
||||
running: make(chan models.LoopStatus),
|
||||
stop: make(chan struct{}),
|
||||
stopped: make(chan struct{}),
|
||||
updateTicker: make(chan struct{}),
|
||||
timeNow: time.Now,
|
||||
timeSince: time.Since,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||
func (l *looper) Stop() { l.stop <- struct{}{} }
|
||||
|
||||
func (l *looper) GetPeriod() (period time.Duration) {
|
||||
l.periodMutex.RLock()
|
||||
defer l.periodMutex.RUnlock()
|
||||
return l.period
|
||||
}
|
||||
|
||||
func (l *looper) SetPeriod(period time.Duration) {
|
||||
l.periodMutex.Lock()
|
||||
l.period = period
|
||||
l.periodMutex.Unlock()
|
||||
l.updateTicker <- struct{}{}
|
||||
}
|
||||
|
||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
l.logger.Error(err)
|
||||
const waitTime = 5 * time.Minute
|
||||
@@ -84,53 +80,72 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
|
||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
crashed := false
|
||||
select {
|
||||
case <-l.restart:
|
||||
l.logger.Info("starting...")
|
||||
case <-l.start:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
enabled := true
|
||||
|
||||
for ctx.Err() == nil {
|
||||
for !enabled {
|
||||
// wait for a signal to re-enable
|
||||
select {
|
||||
case <-l.stop:
|
||||
l.logger.Info("already disabled")
|
||||
case <-l.restart:
|
||||
enabled = true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled and has a period set
|
||||
|
||||
servers, err := l.updater.UpdateServers(ctx)
|
||||
updateCtx, updateCancel := context.WithCancel(ctx)
|
||||
defer updateCancel()
|
||||
serversCh := make(chan models.AllServers)
|
||||
errorCh := make(chan error)
|
||||
go func() {
|
||||
servers, err := l.updater.UpdateServers(updateCtx)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
serversCh <- servers
|
||||
}()
|
||||
|
||||
if !crashed {
|
||||
l.running <- constants.Running
|
||||
crashed = false
|
||||
} else {
|
||||
l.state.setStatusWithLock(constants.Running)
|
||||
}
|
||||
|
||||
stayHere := true
|
||||
for stayHere {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
l.logger.Warn("context canceled: exiting loop")
|
||||
updateCancel()
|
||||
close(errorCh)
|
||||
return
|
||||
case <-l.start:
|
||||
l.logger.Info("starting")
|
||||
updateCancel()
|
||||
stayHere = false
|
||||
case <-l.stop:
|
||||
l.logger.Info("stopping")
|
||||
updateCancel()
|
||||
<-errorCh
|
||||
l.stopped <- struct{}{}
|
||||
case servers := <-serversCh:
|
||||
updateCancel()
|
||||
close(serversCh)
|
||||
l.setAllServers(servers)
|
||||
if err := l.storage.FlushToFile(servers); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
l.state.setStatusWithLock(constants.Completed)
|
||||
l.logger.Info("Updated servers information")
|
||||
|
||||
select {
|
||||
case <-l.restart: // triggered restart
|
||||
case <-l.stop:
|
||||
enabled = false
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-errorCh:
|
||||
updateCancel()
|
||||
close(serversCh)
|
||||
l.state.setStatusWithLock(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
crashed = true
|
||||
stayHere = false
|
||||
}
|
||||
}
|
||||
close(errorCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
@@ -138,7 +153,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
timer := time.NewTimer(time.Hour)
|
||||
timer.Stop()
|
||||
timerIsStopped := true
|
||||
if period := l.GetPeriod(); period > 0 {
|
||||
if period := l.GetSettings().Period; period > 0 {
|
||||
timerIsStopped = false
|
||||
timer.Reset(period)
|
||||
}
|
||||
@@ -152,14 +167,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
return
|
||||
case <-timer.C:
|
||||
lastTick = l.timeNow()
|
||||
l.restart <- struct{}{}
|
||||
timer.Reset(l.GetPeriod())
|
||||
l.start <- struct{}{}
|
||||
timer.Reset(l.GetSettings().Period)
|
||||
case <-l.updateTicker:
|
||||
if !timerIsStopped && !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
timerIsStopped = true
|
||||
period := l.GetPeriod()
|
||||
period := l.GetSettings().Period
|
||||
if period == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package updater
|
||||
|
||||
type Options struct {
|
||||
Cyberghost bool
|
||||
Mullvad bool
|
||||
Nordvpn bool
|
||||
PIA bool
|
||||
Privado bool
|
||||
Purevpn bool
|
||||
Surfshark bool
|
||||
Vyprvpn bool
|
||||
Windscribe bool
|
||||
Stdout bool // in order to update constants file (maintainer side)
|
||||
CLI bool
|
||||
DNSAddress string
|
||||
}
|
||||
|
||||
func NewOptions(dnsAddress string) Options {
|
||||
return Options{
|
||||
Cyberghost: true,
|
||||
Mullvad: true,
|
||||
Nordvpn: true,
|
||||
PIA: true,
|
||||
Purevpn: true,
|
||||
Surfshark: true,
|
||||
Vyprvpn: true,
|
||||
Windscribe: true,
|
||||
Stdout: false,
|
||||
CLI: false,
|
||||
DNSAddress: dnsAddress,
|
||||
}
|
||||
}
|
||||
88
internal/updater/state.go
Normal file
88
internal/updater/state.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
)
|
||||
|
||||
type state struct {
|
||||
status models.LoopStatus
|
||||
settings settings.Updater
|
||||
statusMu sync.RWMutex
|
||||
periodMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *state) setStatusWithLock(status models.LoopStatus) {
|
||||
s.statusMu.Lock()
|
||||
defer s.statusMu.Unlock()
|
||||
s.status = status
|
||||
}
|
||||
|
||||
func (l *looper) GetStatus() (status models.LoopStatus) {
|
||||
l.state.statusMu.RLock()
|
||||
defer l.state.statusMu.RUnlock()
|
||||
return l.state.status
|
||||
}
|
||||
|
||||
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
||||
l.state.statusMu.Lock()
|
||||
defer l.state.statusMu.Unlock()
|
||||
existingStatus := l.state.status
|
||||
|
||||
switch status {
|
||||
case constants.Running:
|
||||
switch existingStatus {
|
||||
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
|
||||
return fmt.Sprintf("already %s", existingStatus), nil
|
||||
}
|
||||
l.loopLock.Lock()
|
||||
defer l.loopLock.Unlock()
|
||||
l.state.status = constants.Starting
|
||||
l.state.statusMu.Unlock()
|
||||
l.start <- struct{}{}
|
||||
newStatus := <-l.running
|
||||
l.state.statusMu.Lock()
|
||||
l.state.status = newStatus
|
||||
return newStatus.String(), nil
|
||||
case constants.Stopped:
|
||||
switch existingStatus {
|
||||
case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed:
|
||||
return fmt.Sprintf("already %s", existingStatus), nil
|
||||
}
|
||||
l.loopLock.Lock()
|
||||
defer l.loopLock.Unlock()
|
||||
l.state.status = constants.Stopping
|
||||
l.state.statusMu.Unlock()
|
||||
l.stop <- struct{}{}
|
||||
<-l.stopped
|
||||
l.state.statusMu.Lock()
|
||||
l.state.status = status
|
||||
return status.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("status %q can only be %q or %q",
|
||||
status, constants.Running, constants.Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.Updater) {
|
||||
l.state.periodMu.RLock()
|
||||
defer l.state.periodMu.RUnlock()
|
||||
return l.state.settings
|
||||
}
|
||||
|
||||
func (l *looper) SetSettings(settings settings.Updater) (outcome string) {
|
||||
l.state.periodMu.Lock()
|
||||
defer l.state.periodMu.Unlock()
|
||||
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
||||
if settingsUnchanged {
|
||||
return "settings left unchanged"
|
||||
}
|
||||
l.state.settings = settings
|
||||
l.updateTicker <- struct{}{}
|
||||
return "settings updated"
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -17,7 +18,7 @@ type Updater interface {
|
||||
|
||||
type updater struct {
|
||||
// configuration
|
||||
options Options
|
||||
options settings.Updater
|
||||
|
||||
// state
|
||||
servers models.AllServers
|
||||
@@ -30,11 +31,12 @@ type updater struct {
|
||||
client network.Client
|
||||
}
|
||||
|
||||
func New(options Options, httpClient *http.Client, currentServers models.AllServers, logger logging.Logger) Updater {
|
||||
if len(options.DNSAddress) == 0 {
|
||||
options.DNSAddress = "1.1.1.1"
|
||||
func New(settings settings.Updater, httpClient *http.Client,
|
||||
currentServers models.AllServers, logger logging.Logger) Updater {
|
||||
if len(settings.DNSAddress) == 0 {
|
||||
settings.DNSAddress = "1.1.1.1"
|
||||
}
|
||||
resolver := newResolver(options.DNSAddress)
|
||||
resolver := newResolver(settings.DNSAddress)
|
||||
const clientTimeout = 10 * time.Second
|
||||
return &updater{
|
||||
logger: logger,
|
||||
@@ -42,7 +44,7 @@ func New(options Options, httpClient *http.Client, currentServers models.AllServ
|
||||
println: func(s string) { fmt.Println(s) },
|
||||
lookupIP: newLookupIP(resolver),
|
||||
client: network.NewClient(clientTimeout),
|
||||
options: options,
|
||||
options: settings,
|
||||
servers: currentServers,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user