Loops and HTTP control server rework (#308)
- CRUD REST HTTP server - `/v1` HTTP server prefix - Retrocompatible with older routes (redirects to v1 or handles the requests directly) - DNS, Updater and Openvpn refactored to have a REST-like state with new methods to change their states synchronously - Openvpn, Unbound and Updater status, see #287
This commit is contained in:
@@ -217,15 +217,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
|
|
||||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
||||||
|
|
||||||
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
|
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, uid, gid, allServers,
|
||||||
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
|
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
// wait for restartOpenvpn
|
// wait for restartOpenvpn
|
||||||
go openvpnLooper.Run(ctx, wg)
|
go openvpnLooper.Run(ctx, wg)
|
||||||
|
|
||||||
updaterOptions := updater.NewOptions("127.0.0.1")
|
updaterLooper := updater.NewLooper(allSettings.Updater,
|
||||||
updaterLooper := updater.NewLooper(updaterOptions, allSettings.UpdaterPeriod,
|
allServers, storage, openvpnLooper.SetServers, httpClient, logger)
|
||||||
allServers, storage, openvpnLooper.SetAllServers, httpClient, logger)
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
|
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
|
||||||
go updaterLooper.Run(ctx, wg)
|
go updaterLooper.Run(ctx, wg)
|
||||||
@@ -276,8 +275,9 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go healthcheckServer.Run(ctx, wg)
|
go healthcheckServer.Run(ctx, wg)
|
||||||
|
|
||||||
// Start openvpn for the first time
|
// Start openvpn for the first time in a blocking call
|
||||||
openvpnLooper.Restart()
|
// until openvpn is launched
|
||||||
|
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
||||||
|
|
||||||
signalsCh := make(chan os.Signal, 1)
|
signalsCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(signalsCh,
|
signal.Notify(signalsCh,
|
||||||
@@ -401,7 +401,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
|
|||||||
tickerWg.Wait()
|
tickerWg.Wait()
|
||||||
return
|
return
|
||||||
case <-tunnelReadyCh: // blocks until openvpn is connected
|
case <-tunnelReadyCh: // blocks until openvpn is connected
|
||||||
unboundLooper.Restart()
|
_, _ = unboundLooper.SetStatus(constants.Running)
|
||||||
restartTickerCancel() // stop previous restart tickers
|
restartTickerCancel() // stop previous restart tickers
|
||||||
tickerWg.Wait()
|
tickerWg.Wait()
|
||||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ func OpenvpnConfig() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Update(args []string) error {
|
func Update(args []string) error {
|
||||||
options := updater.Options{CLI: true}
|
options := settings.Updater{CLI: true}
|
||||||
var flushToFile bool
|
var flushToFile bool
|
||||||
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
||||||
flagSet.BoolVar(&flushToFile, "file", false, "Write results to /gluetun/servers.json (for end users)")
|
flagSet.BoolVar(&flushToFile, "file", false, "Write results to /gluetun/servers.json (for end users)")
|
||||||
|
|||||||
14
internal/constants/status.go
Normal file
14
internal/constants/status.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package constants
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Starting models.LoopStatus = "starting"
|
||||||
|
Running models.LoopStatus = "running"
|
||||||
|
Stopping models.LoopStatus = "stopping"
|
||||||
|
Stopped models.LoopStatus = "stopped"
|
||||||
|
Crashed models.LoopStatus = "crashed"
|
||||||
|
Completed models.LoopStatus = "completed"
|
||||||
|
)
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/settings"
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/golibs/command"
|
"github.com/qdm12/golibs/command"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
@@ -15,80 +16,51 @@ import (
|
|||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func())
|
Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func())
|
||||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||||
Restart()
|
GetStatus() (status models.LoopStatus)
|
||||||
Start()
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
Stop()
|
|
||||||
GetSettings() (settings settings.DNS)
|
GetSettings() (settings settings.DNS)
|
||||||
SetSettings(settings settings.DNS)
|
SetSettings(settings settings.DNS) (outcome string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
conf Configurator
|
state state
|
||||||
settings settings.DNS
|
conf Configurator
|
||||||
settingsMutex sync.RWMutex
|
logger logging.Logger
|
||||||
logger logging.Logger
|
streamMerger command.StreamMerger
|
||||||
streamMerger command.StreamMerger
|
uid int
|
||||||
uid int
|
gid int
|
||||||
gid int
|
loopLock sync.Mutex
|
||||||
restart chan struct{}
|
start chan struct{}
|
||||||
start chan struct{}
|
running chan models.LoopStatus
|
||||||
stop chan struct{}
|
stop chan struct{}
|
||||||
updateTicker chan struct{}
|
stopped chan struct{}
|
||||||
timeNow func() time.Time
|
updateTicker chan struct{}
|
||||||
timeSince func(time.Time) time.Duration
|
timeNow func() time.Time
|
||||||
|
timeSince func(time.Time) time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
||||||
streamMerger command.StreamMerger, uid, gid int) Looper {
|
streamMerger command.StreamMerger, uid, gid int) Looper {
|
||||||
return &looper{
|
return &looper{
|
||||||
|
state: state{
|
||||||
|
status: constants.Stopped,
|
||||||
|
settings: settings,
|
||||||
|
},
|
||||||
conf: conf,
|
conf: conf,
|
||||||
settings: settings,
|
|
||||||
logger: logger.WithPrefix("dns over tls: "),
|
logger: logger.WithPrefix("dns over tls: "),
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
streamMerger: streamMerger,
|
streamMerger: streamMerger,
|
||||||
restart: make(chan struct{}),
|
|
||||||
start: make(chan struct{}),
|
start: make(chan struct{}),
|
||||||
|
running: make(chan models.LoopStatus),
|
||||||
stop: make(chan struct{}),
|
stop: make(chan struct{}),
|
||||||
|
stopped: make(chan struct{}),
|
||||||
updateTicker: make(chan struct{}),
|
updateTicker: make(chan struct{}),
|
||||||
timeNow: time.Now,
|
timeNow: time.Now,
|
||||||
timeSince: time.Since,
|
timeSince: time.Since,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
|
||||||
func (l *looper) Start() { l.start <- struct{}{} }
|
|
||||||
func (l *looper) Stop() { l.stop <- struct{}{} }
|
|
||||||
|
|
||||||
func (l *looper) GetSettings() (settings settings.DNS) {
|
|
||||||
l.settingsMutex.RLock()
|
|
||||||
defer l.settingsMutex.RUnlock()
|
|
||||||
return l.settings
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) SetSettings(settings settings.DNS) {
|
|
||||||
l.settingsMutex.Lock()
|
|
||||||
defer l.settingsMutex.Unlock()
|
|
||||||
updatePeriodDiffers := l.settings.UpdatePeriod != settings.UpdatePeriod
|
|
||||||
l.settings = settings
|
|
||||||
l.settingsMutex.Unlock()
|
|
||||||
if updatePeriodDiffers {
|
|
||||||
l.updateTicker <- struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) isEnabled() bool {
|
|
||||||
l.settingsMutex.RLock()
|
|
||||||
defer l.settingsMutex.RUnlock()
|
|
||||||
return l.settings.Enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) setEnabled(enabled bool) {
|
|
||||||
l.settingsMutex.Lock()
|
|
||||||
defer l.settingsMutex.Unlock()
|
|
||||||
l.settings.Enabled = enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
l.logger.Warn(err)
|
l.logger.Warn(err)
|
||||||
l.logger.Info("attempting restart in 10 seconds")
|
l.logger.Info("attempting restart in 10 seconds")
|
||||||
@@ -103,96 +75,42 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-l.stop:
|
|
||||||
l.setEnabled(false)
|
|
||||||
l.logger.Info("not started yet")
|
|
||||||
case <-l.restart:
|
|
||||||
if l.isEnabled() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
signalDNSReady()
|
|
||||||
l.logger.Info("not restarting because disabled")
|
|
||||||
case <-l.start:
|
|
||||||
l.setEnabled(true)
|
|
||||||
return
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel context.CancelFunc) {
|
|
||||||
if l.isEnabled() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
// wait for a signal to re-enable
|
|
||||||
select {
|
|
||||||
case <-l.stop:
|
|
||||||
l.logger.Info("already disabled")
|
|
||||||
case <-l.restart:
|
|
||||||
if !l.isEnabled() {
|
|
||||||
l.logger.Info("not restarting because disabled")
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-l.start:
|
|
||||||
l.setEnabled(true)
|
|
||||||
return
|
|
||||||
case <-ctx.Done():
|
|
||||||
unboundCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution?
|
||||||
l.waitForFirstStart(ctx, signalDNSReady)
|
|
||||||
if ctx.Err() != nil {
|
select {
|
||||||
|
case <-l.start:
|
||||||
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer l.logger.Warn("loop exited")
|
defer l.logger.Warn("loop exited")
|
||||||
|
|
||||||
var unboundCtx context.Context
|
|
||||||
var unboundCancel context.CancelFunc = func() {}
|
|
||||||
var waitError chan error
|
|
||||||
triggeredRestart := false
|
|
||||||
l.setEnabled(true)
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
l.waitForSubsequentStart(ctx, unboundCancel)
|
err := l.updateFiles(ctx)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
|
l.logAndWait(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
crashed := false
|
||||||
|
|
||||||
|
for ctx.Err() == nil {
|
||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
|
|
||||||
// Setup
|
unboundCtx, unboundCancel := context.WithCancel(context.Background())
|
||||||
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
|
|
||||||
l.logAndWait(ctx, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
|
|
||||||
l.logAndWait(ctx, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
|
|
||||||
l.logAndWait(ctx, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if triggeredRestart {
|
|
||||||
triggeredRestart = false
|
|
||||||
unboundCancel()
|
|
||||||
<-waitError
|
|
||||||
close(waitError)
|
|
||||||
}
|
|
||||||
unboundCtx, unboundCancel = context.WithCancel(context.Background())
|
|
||||||
stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel)
|
stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unboundCancel()
|
unboundCancel()
|
||||||
|
if !crashed {
|
||||||
|
l.running <- constants.Crashed
|
||||||
|
}
|
||||||
|
crashed = true
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(fallback)
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
@@ -201,23 +119,37 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
|
|||||||
|
|
||||||
// Started successfully
|
// Started successfully
|
||||||
go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound"))
|
go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound"))
|
||||||
|
|
||||||
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||||
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
|
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := l.conf.WaitForUnbound(); err != nil {
|
if err := l.conf.WaitForUnbound(); err != nil {
|
||||||
|
if !crashed {
|
||||||
|
l.running <- constants.Crashed
|
||||||
|
crashed = true
|
||||||
|
}
|
||||||
unboundCancel()
|
unboundCancel()
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(fallback)
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
waitError = make(chan error)
|
|
||||||
|
waitError := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
err := waitFn() // blocking
|
err := waitFn() // blocking
|
||||||
waitError <- err
|
waitError <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
l.logger.Info("DNS over TLS is ready")
|
l.logger.Info("DNS over TLS is ready")
|
||||||
|
if !crashed {
|
||||||
|
l.running <- constants.Running
|
||||||
|
crashed = false
|
||||||
|
} else {
|
||||||
|
l.state.setStatusWithLock(constants.Running)
|
||||||
|
}
|
||||||
signalDNSReady()
|
signalDNSReady()
|
||||||
|
|
||||||
stayHere := true
|
stayHere := true
|
||||||
@@ -229,31 +161,28 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
|
|||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
return
|
return
|
||||||
case <-l.restart: // triggered restart
|
|
||||||
l.logger.Info("restarting")
|
|
||||||
// unboundCancel occurs next loop run when the setup is complete
|
|
||||||
triggeredRestart = true
|
|
||||||
stayHere = false
|
|
||||||
case <-l.start:
|
|
||||||
l.logger.Info("already started")
|
|
||||||
case <-l.stop:
|
case <-l.stop:
|
||||||
l.logger.Info("stopping")
|
l.logger.Info("stopping")
|
||||||
|
const fallback = false
|
||||||
|
l.useUnencryptedDNS(fallback)
|
||||||
unboundCancel()
|
unboundCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
l.stopped <- struct{}{}
|
||||||
l.setEnabled(false)
|
case <-l.start:
|
||||||
|
l.logger.Info("starting")
|
||||||
stayHere = false
|
stayHere = false
|
||||||
case err := <-waitError: // unexpected error
|
case err := <-waitError: // unexpected error
|
||||||
close(waitError)
|
|
||||||
unboundCancel()
|
unboundCancel()
|
||||||
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
const fallback = true
|
const fallback = true
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(fallback)
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
stayHere = false
|
stayHere = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
close(waitError)
|
||||||
|
unboundCancel()
|
||||||
}
|
}
|
||||||
unboundCancel()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) useUnencryptedDNS(fallback bool) {
|
func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||||
@@ -279,7 +208,11 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
|||||||
data := constants.DNSProviderMapping()[provider]
|
data := constants.DNSProviderMapping()[provider]
|
||||||
for _, targetIP = range data.IPs {
|
for _, targetIP = range data.IPs {
|
||||||
if targetIP.To4() != nil {
|
if targetIP.To4() != nil {
|
||||||
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
|
if fallback {
|
||||||
|
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
|
||||||
|
} else {
|
||||||
|
l.logger.Info("using plaintext DNS at address %s", targetIP)
|
||||||
|
}
|
||||||
l.conf.UseDNSInternally(targetIP)
|
l.conf.UseDNSInternally(targetIP)
|
||||||
if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil {
|
if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
@@ -314,7 +247,20 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
return
|
return
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
lastTick = l.timeNow()
|
lastTick = l.timeNow()
|
||||||
l.restart <- struct{}{}
|
|
||||||
|
status := l.GetStatus()
|
||||||
|
if status == constants.Running {
|
||||||
|
if err := l.updateFiles(ctx); err != nil {
|
||||||
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
|
l.logger.Error(err)
|
||||||
|
l.logger.Warn("skipping Unbound restart due to failed files update")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = l.SetStatus(constants.Stopped)
|
||||||
|
_, _ = l.SetStatus(constants.Running)
|
||||||
|
|
||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
timer.Reset(settings.UpdatePeriod)
|
timer.Reset(settings.UpdatePeriod)
|
||||||
case <-l.updateTicker:
|
case <-l.updateTicker:
|
||||||
@@ -337,3 +283,17 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *looper) updateFiles(ctx context.Context) (err error) {
|
||||||
|
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
settings := l.GetSettings()
|
||||||
|
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
96
internal/dns/state.go
Normal file
96
internal/dns/state.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type state struct {
|
||||||
|
status models.LoopStatus
|
||||||
|
settings settings.DNS
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
settingsMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *state) setStatusWithLock(status models.LoopStatus) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.status = status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetStatus() (status models.LoopStatus) {
|
||||||
|
l.state.statusMu.RLock()
|
||||||
|
defer l.state.statusMu.RUnlock()
|
||||||
|
return l.state.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
defer l.state.statusMu.Unlock()
|
||||||
|
existingStatus := l.state.status
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case constants.Running:
|
||||||
|
switch existingStatus {
|
||||||
|
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
|
||||||
|
return fmt.Sprintf("already %s", existingStatus), nil
|
||||||
|
}
|
||||||
|
l.loopLock.Lock()
|
||||||
|
defer l.loopLock.Unlock()
|
||||||
|
l.state.status = constants.Starting
|
||||||
|
l.state.statusMu.Unlock()
|
||||||
|
l.start <- struct{}{}
|
||||||
|
newStatus := <-l.running
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
l.state.status = newStatus
|
||||||
|
return newStatus.String(), nil
|
||||||
|
case constants.Stopped:
|
||||||
|
switch existingStatus {
|
||||||
|
case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
|
||||||
|
return fmt.Sprintf("already %s", existingStatus), nil
|
||||||
|
}
|
||||||
|
l.loopLock.Lock()
|
||||||
|
defer l.loopLock.Unlock()
|
||||||
|
l.state.status = constants.Stopping
|
||||||
|
l.state.statusMu.Unlock()
|
||||||
|
l.stop <- struct{}{}
|
||||||
|
<-l.stopped
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
l.state.status = constants.Stopped
|
||||||
|
return status.String(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("status %q can only be %q or %q",
|
||||||
|
status, constants.Running, constants.Stopped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetSettings() (settings settings.DNS) {
|
||||||
|
l.state.settingsMu.RLock()
|
||||||
|
defer l.state.settingsMu.RUnlock()
|
||||||
|
return l.state.settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetSettings(settings settings.DNS) (outcome string) {
|
||||||
|
l.state.settingsMu.Lock()
|
||||||
|
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
|
||||||
|
if settingsUnchanged {
|
||||||
|
l.state.settingsMu.Unlock()
|
||||||
|
return "settings left unchanged"
|
||||||
|
}
|
||||||
|
tempSettings := l.state.settings
|
||||||
|
tempSettings.UpdatePeriod = settings.UpdatePeriod
|
||||||
|
onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings)
|
||||||
|
l.state.settings = settings
|
||||||
|
if onlyUpdatePeriodChanged {
|
||||||
|
l.updateTicker <- struct{}{}
|
||||||
|
return "update period changed"
|
||||||
|
}
|
||||||
|
_, _ = l.SetStatus(constants.Stopped)
|
||||||
|
outcome, _ = l.SetStatus(constants.Running)
|
||||||
|
return outcome
|
||||||
|
}
|
||||||
@@ -20,8 +20,14 @@ type (
|
|||||||
VPNProvider string
|
VPNProvider string
|
||||||
// NetworkProtocol contains the network protocol to be used to communicate with the VPN servers.
|
// NetworkProtocol contains the network protocol to be used to communicate with the VPN servers.
|
||||||
NetworkProtocol string
|
NetworkProtocol string
|
||||||
|
// Loop status such as stopped or running.
|
||||||
|
LoopStatus string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (ls LoopStatus) String() string {
|
||||||
|
return string(ls)
|
||||||
|
}
|
||||||
|
|
||||||
func marshalJSONString(s string) (data []byte, err error) {
|
func marshalJSONString(s string) (data []byte, err error) {
|
||||||
return []byte(fmt.Sprintf("%q", s)), nil
|
return []byte(fmt.Sprintf("%q", s)), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package models
|
|||||||
type BuildInformation struct {
|
type BuildInformation struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
Commit string `json:"commit"`
|
Commit string `json:"commit"`
|
||||||
BuildDate string `json:"buildDate"`
|
BuildDate string `json:"build_date"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import "net"
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
type OpenVPNConnection struct {
|
type OpenVPNConnection struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import (
|
|||||||
// ProviderSettings contains settings specific to a VPN provider.
|
// ProviderSettings contains settings specific to a VPN provider.
|
||||||
type ProviderSettings struct {
|
type ProviderSettings struct {
|
||||||
Name VPNProvider `json:"name"`
|
Name VPNProvider `json:"name"`
|
||||||
ServerSelection ServerSelection `json:"serverSelection"`
|
ServerSelection ServerSelection `json:"server_selection"`
|
||||||
ExtraConfigOptions ExtraConfigOptions `json:"extraConfig"`
|
ExtraConfigOptions ExtraConfigOptions `json:"extra_config"`
|
||||||
PortForwarding PortForwarding `json:"portForwarding"`
|
PortForwarding PortForwarding `json:"port_forwarding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerSelection struct {
|
type ServerSelection struct {
|
||||||
// Common
|
// Common
|
||||||
Protocol NetworkProtocol `json:"networkProtocol"`
|
Protocol NetworkProtocol `json:"network_protocol"`
|
||||||
TargetIP net.IP `json:"targetIP,omitempty"`
|
TargetIP net.IP `json:"target_ip,omitempty"`
|
||||||
|
|
||||||
// Cyberghost, PIA, Surfshark, Windscribe, Vyprvpn, NordVPN
|
// Cyberghost, PIA, Surfshark, Windscribe, Vyprvpn, NordVPN
|
||||||
Regions []string `json:"regions"`
|
Regions []string `json:"regions"`
|
||||||
@@ -34,20 +34,20 @@ type ServerSelection struct {
|
|||||||
Owned bool `json:"owned"`
|
Owned bool `json:"owned"`
|
||||||
|
|
||||||
// Mullvad, Windscribe
|
// Mullvad, Windscribe
|
||||||
CustomPort uint16 `json:"customPort"`
|
CustomPort uint16 `json:"custom_port"`
|
||||||
|
|
||||||
// NordVPN
|
// NordVPN
|
||||||
Numbers []uint16 `json:"numbers"`
|
Numbers []uint16 `json:"numbers"`
|
||||||
|
|
||||||
// PIA
|
// PIA
|
||||||
EncryptionPreset string `json:"encryptionPreset"`
|
EncryptionPreset string `json:"encryption_preset"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExtraConfigOptions struct {
|
type ExtraConfigOptions struct {
|
||||||
ClientCertificate string `json:"-"` // Cyberghost
|
ClientCertificate string `json:"-"` // Cyberghost
|
||||||
ClientKey string `json:"-"` // Cyberghost
|
ClientKey string `json:"-"` // Cyberghost
|
||||||
EncryptionPreset string `json:"encryptionPreset"` // PIA
|
EncryptionPreset string `json:"encryption_preset"` // PIA
|
||||||
OpenVPNIPv6 bool `json:"openvpnIPv6"` // Mullvad
|
OpenVPNIPv6 bool `json:"openvpn_ipv6"` // Mullvad
|
||||||
}
|
}
|
||||||
|
|
||||||
// PortForwarding contains settings for port forwarding.
|
// PortForwarding contains settings for port forwarding.
|
||||||
|
|||||||
@@ -20,23 +20,18 @@ import (
|
|||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
Restart()
|
GetStatus() (status models.LoopStatus)
|
||||||
PortForward(vpnGatewayIP net.IP)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetSettings() (settings settings.OpenVPN)
|
GetSettings() (settings settings.OpenVPN)
|
||||||
SetSettings(settings settings.OpenVPN)
|
SetSettings(settings settings.OpenVPN) (outcome string)
|
||||||
GetPortForwarded() (portForwarded uint16)
|
GetServers() (servers models.AllServers)
|
||||||
SetAllServers(allServers models.AllServers)
|
SetServers(servers models.AllServers)
|
||||||
|
GetPortForwarded() (port uint16)
|
||||||
|
PortForward(vpnGatewayIP net.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
// Variable parameters
|
state state
|
||||||
provider models.VPNProvider
|
|
||||||
settings settings.OpenVPN
|
|
||||||
settingsMutex sync.RWMutex
|
|
||||||
portForwarded uint16
|
|
||||||
portForwardedMutex sync.RWMutex
|
|
||||||
allServers models.AllServers
|
|
||||||
allServersMutex sync.RWMutex
|
|
||||||
// Fixed parameters
|
// Fixed parameters
|
||||||
uid int
|
uid int
|
||||||
gid int
|
gid int
|
||||||
@@ -50,22 +45,27 @@ type looper struct {
|
|||||||
fileManager files.FileManager
|
fileManager files.FileManager
|
||||||
streamMerger command.StreamMerger
|
streamMerger command.StreamMerger
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
// Internal channels
|
// Internal channels and locks
|
||||||
restart chan struct{}
|
loopLock sync.Mutex
|
||||||
|
running chan models.LoopStatus
|
||||||
|
stop, stopped chan struct{}
|
||||||
|
start chan struct{}
|
||||||
portForwardSignals chan net.IP
|
portForwardSignals chan net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
func NewLooper(settings settings.OpenVPN,
|
||||||
uid, gid int, allServers models.AllServers,
|
uid, gid int, allServers models.AllServers,
|
||||||
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
||||||
logger logging.Logger, client *http.Client, fileManager files.FileManager,
|
logger logging.Logger, client *http.Client, fileManager files.FileManager,
|
||||||
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
|
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
|
||||||
return &looper{
|
return &looper{
|
||||||
provider: provider,
|
state: state{
|
||||||
settings: settings,
|
status: constants.Stopped,
|
||||||
|
settings: settings,
|
||||||
|
allServers: allServers,
|
||||||
|
},
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
allServers: allServers,
|
|
||||||
conf: conf,
|
conf: conf,
|
||||||
fw: fw,
|
fw: fw,
|
||||||
routing: routing,
|
routing: routing,
|
||||||
@@ -75,46 +75,29 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
|||||||
fileManager: fileManager,
|
fileManager: fileManager,
|
||||||
streamMerger: streamMerger,
|
streamMerger: streamMerger,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
restart: make(chan struct{}),
|
start: make(chan struct{}),
|
||||||
|
running: make(chan models.LoopStatus),
|
||||||
|
stop: make(chan struct{}),
|
||||||
|
stopped: make(chan struct{}),
|
||||||
portForwardSignals: make(chan net.IP),
|
portForwardSignals: make(chan net.IP),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
|
||||||
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
|
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
|
||||||
|
|
||||||
func (l *looper) GetSettings() (settings settings.OpenVPN) {
|
|
||||||
l.settingsMutex.RLock()
|
|
||||||
defer l.settingsMutex.RUnlock()
|
|
||||||
return l.settings
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) SetSettings(settings settings.OpenVPN) {
|
|
||||||
l.settingsMutex.Lock()
|
|
||||||
defer l.settingsMutex.Unlock()
|
|
||||||
l.settings = settings
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) SetAllServers(allServers models.AllServers) {
|
|
||||||
l.allServersMutex.Lock()
|
|
||||||
defer l.allServersMutex.Unlock()
|
|
||||||
l.allServers = allServers
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
crashed := false
|
||||||
select {
|
select {
|
||||||
case <-l.restart:
|
case <-l.start:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
defer l.logger.Warn("loop exited")
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
settings := l.GetSettings()
|
settings, allServers := l.state.getSettingsAndServers()
|
||||||
l.allServersMutex.RLock()
|
providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
|
||||||
providerConf := provider.New(l.provider, l.allServers, time.Now)
|
|
||||||
l.allServersMutex.RUnlock()
|
|
||||||
connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection)
|
connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
@@ -155,6 +138,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openvpnCancel()
|
openvpnCancel()
|
||||||
|
if !crashed {
|
||||||
|
l.running <- constants.Crashed
|
||||||
|
crashed = true
|
||||||
|
}
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -179,23 +166,41 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
err := waitFn() // blocking
|
err := waitFn() // blocking
|
||||||
waitError <- err
|
waitError <- err
|
||||||
}()
|
}()
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
if !crashed {
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
l.running <- constants.Running
|
||||||
openvpnCancel()
|
crashed = false
|
||||||
<-waitError
|
} else {
|
||||||
close(waitError)
|
l.state.setStatusWithLock(constants.Running)
|
||||||
return
|
|
||||||
case <-l.restart: // triggered restart
|
|
||||||
l.logger.Info("restarting")
|
|
||||||
openvpnCancel()
|
|
||||||
<-waitError
|
|
||||||
close(waitError)
|
|
||||||
case err := <-waitError: // unexpected error
|
|
||||||
openvpnCancel()
|
|
||||||
close(waitError)
|
|
||||||
l.logAndWait(ctx, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stayHere := true
|
||||||
|
for stayHere {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.logger.Warn("context canceled: exiting loop")
|
||||||
|
openvpnCancel()
|
||||||
|
<-waitError
|
||||||
|
close(waitError)
|
||||||
|
return
|
||||||
|
case <-l.stop:
|
||||||
|
l.logger.Info("stopping")
|
||||||
|
openvpnCancel()
|
||||||
|
<-waitError
|
||||||
|
l.stopped <- struct{}{}
|
||||||
|
case <-l.start:
|
||||||
|
l.logger.Info("starting")
|
||||||
|
stayHere = false
|
||||||
|
case err := <-waitError: // unexpected error
|
||||||
|
openvpnCancel()
|
||||||
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
|
l.logAndWait(ctx, err)
|
||||||
|
crashed = true
|
||||||
|
stayHere = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(waitError)
|
||||||
|
openvpnCancel() // just for the linter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,24 +223,21 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
||||||
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
settings := l.GetSettings()
|
l.state.portForwardedMu.RLock()
|
||||||
|
settings := l.state.settings
|
||||||
|
l.state.portForwardedMu.RUnlock()
|
||||||
if !settings.Provider.PortForwarding.Enabled {
|
if !settings.Provider.PortForwarding.Enabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
syncState := func(port uint16) (pfFilepath models.Filepath) {
|
syncState := func(port uint16) (pfFilepath models.Filepath) {
|
||||||
l.portForwardedMutex.Lock()
|
l.state.portForwardedMu.Lock()
|
||||||
l.portForwarded = port
|
defer l.state.portForwardedMu.Unlock()
|
||||||
l.portForwardedMutex.Unlock()
|
l.state.portForwarded = port
|
||||||
settings := l.GetSettings()
|
l.state.settingsMu.RLock()
|
||||||
|
defer l.state.settingsMu.RUnlock()
|
||||||
return settings.Provider.PortForwarding.Filepath
|
return settings.Provider.PortForwarding.Filepath
|
||||||
}
|
}
|
||||||
providerConf.PortForward(ctx,
|
providerConf.PortForward(ctx,
|
||||||
client, l.fileManager, l.pfLogger,
|
client, l.fileManager, l.pfLogger,
|
||||||
gateway, l.fw, syncState)
|
gateway, l.fw, syncState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) GetPortForwarded() (portForwarded uint16) {
|
|
||||||
l.portForwardedMutex.RLock()
|
|
||||||
defer l.portForwardedMutex.RUnlock()
|
|
||||||
return l.portForwarded
|
|
||||||
}
|
|
||||||
|
|||||||
121
internal/openvpn/state.go
Normal file
121
internal/openvpn/state.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package openvpn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type state struct {
|
||||||
|
status models.LoopStatus
|
||||||
|
settings settings.OpenVPN
|
||||||
|
allServers models.AllServers
|
||||||
|
portForwarded uint16
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
settingsMu sync.RWMutex
|
||||||
|
allServersMu sync.RWMutex
|
||||||
|
portForwardedMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *state) setStatusWithLock(status models.LoopStatus) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.status = status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *state) getSettingsAndServers() (settings settings.OpenVPN, allServers models.AllServers) {
|
||||||
|
s.settingsMu.RLock()
|
||||||
|
s.allServersMu.RLock()
|
||||||
|
settings = s.settings
|
||||||
|
allServers = s.allServers
|
||||||
|
s.settingsMu.RLock()
|
||||||
|
s.allServersMu.RLock()
|
||||||
|
return settings, allServers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetStatus() (status models.LoopStatus) {
|
||||||
|
l.state.statusMu.RLock()
|
||||||
|
defer l.state.statusMu.RUnlock()
|
||||||
|
return l.state.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
defer l.state.statusMu.Unlock()
|
||||||
|
existingStatus := l.state.status
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case constants.Running:
|
||||||
|
switch existingStatus {
|
||||||
|
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
|
||||||
|
return fmt.Sprintf("already %s", existingStatus), nil
|
||||||
|
}
|
||||||
|
l.loopLock.Lock()
|
||||||
|
defer l.loopLock.Unlock()
|
||||||
|
l.state.status = constants.Starting
|
||||||
|
l.state.statusMu.Unlock()
|
||||||
|
l.start <- struct{}{}
|
||||||
|
newStatus := <-l.running
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
l.state.status = newStatus
|
||||||
|
return newStatus.String(), nil
|
||||||
|
case constants.Stopped:
|
||||||
|
switch existingStatus {
|
||||||
|
case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
|
||||||
|
return fmt.Sprintf("already %s", existingStatus), nil
|
||||||
|
}
|
||||||
|
l.loopLock.Lock()
|
||||||
|
defer l.loopLock.Unlock()
|
||||||
|
l.state.status = constants.Stopping
|
||||||
|
l.state.statusMu.Unlock()
|
||||||
|
l.stop <- struct{}{}
|
||||||
|
<-l.stopped
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
l.state.status = constants.Stopped
|
||||||
|
return status.String(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("status %q can only be %q or %q",
|
||||||
|
status, constants.Running, constants.Stopped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetSettings() (settings settings.OpenVPN) {
|
||||||
|
l.state.settingsMu.RLock()
|
||||||
|
defer l.state.settingsMu.RUnlock()
|
||||||
|
return l.state.settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetSettings(settings settings.OpenVPN) (outcome string) {
|
||||||
|
l.state.settingsMu.Lock()
|
||||||
|
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
|
||||||
|
if settingsUnchanged {
|
||||||
|
l.state.settingsMu.Unlock()
|
||||||
|
return "settings left unchanged"
|
||||||
|
}
|
||||||
|
l.state.settings = settings
|
||||||
|
_, _ = l.SetStatus(constants.Stopped)
|
||||||
|
outcome, _ = l.SetStatus(constants.Running)
|
||||||
|
return outcome
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetServers() (servers models.AllServers) {
|
||||||
|
l.state.allServersMu.RLock()
|
||||||
|
defer l.state.allServersMu.RUnlock()
|
||||||
|
return l.state.allServers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetServers(servers models.AllServers) {
|
||||||
|
l.state.allServersMu.Lock()
|
||||||
|
defer l.state.allServersMu.Unlock()
|
||||||
|
l.state.allServers = servers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetPortForwarded() (port uint16) {
|
||||||
|
l.state.portForwardedMu.RLock()
|
||||||
|
defer l.state.portForwardedMu.RUnlock()
|
||||||
|
return port
|
||||||
|
}
|
||||||
@@ -130,8 +130,8 @@ func (r *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err
|
|||||||
return privateAddresses, nil
|
return privateAddresses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using ipv6 DNS over TLS
|
// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using
|
||||||
// servers from the environment variable DOT_IPV6.
|
// ipv6 DNS over TLS from the environment variable DOT_IPV6.
|
||||||
func (r *reader) GetDNSOverTLSIPv6() (ipv6 bool, err error) {
|
func (r *reader) GetDNSOverTLSIPv6() (ipv6 bool, err error) {
|
||||||
return r.envParams.GetOnOff("DOT_IPV6", libparams.Default("off"))
|
return r.envParams.GetOnOff("DOT_IPV6", libparams.Default("off"))
|
||||||
}
|
}
|
||||||
|
|||||||
76
internal/server/dns.go
Normal file
76
internal/server/dns.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
//nolint:dupl
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/dns"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newDNSHandler(looper dns.Looper, logger logging.Logger) http.Handler {
|
||||||
|
return &dnsHandler{
|
||||||
|
looper: looper,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dnsHandler struct {
|
||||||
|
looper dns.Looper
|
||||||
|
logger logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/dns")
|
||||||
|
switch r.RequestURI {
|
||||||
|
case "/status": //nolint:goconst
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getStatus(w)
|
||||||
|
case http.MethodPut:
|
||||||
|
h.setStatus(w, r)
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) getStatus(w http.ResponseWriter) {
|
||||||
|
status := h.looper.GetStatus()
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
data := statusWrapper{Status: string(status)}
|
||||||
|
if err := encoder.Encode(data); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
var data statusWrapper
|
||||||
|
if err := decoder.Decode(&data); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
status, err := data.getStatus()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outcome, err := h.looper.SetStatus(status)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/dns"
|
"github.com/qdm12/gluetun/internal/dns"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
@@ -17,54 +17,33 @@ func newHandler(logger logging.Logger, logging bool,
|
|||||||
unboundLooper dns.Looper,
|
unboundLooper dns.Looper,
|
||||||
updaterLooper updater.Looper,
|
updaterLooper updater.Looper,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return &handler{
|
handler := &handler{}
|
||||||
logger: logger,
|
|
||||||
logging: logging,
|
openvpn := newOpenvpnHandler(openvpnLooper, logger)
|
||||||
buildInfo: buildInfo,
|
dns := newDNSHandler(unboundLooper, logger)
|
||||||
openvpnLooper: openvpnLooper,
|
updater := newUpdaterHandler(updaterLooper, logger)
|
||||||
unboundLooper: unboundLooper,
|
|
||||||
updaterLooper: updaterLooper,
|
handler.v0 = newHandlerV0(logger, openvpnLooper, unboundLooper, updaterLooper)
|
||||||
}
|
handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater)
|
||||||
|
|
||||||
|
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
||||||
|
handler.setLogEnabled = handlerWithLog.setEnabled
|
||||||
|
|
||||||
|
return handlerWithLog
|
||||||
}
|
}
|
||||||
|
|
||||||
type handler struct {
|
type handler struct {
|
||||||
logger logging.Logger
|
v0 http.Handler
|
||||||
logging bool
|
v1 http.Handler
|
||||||
buildInfo models.BuildInformation
|
setLogEnabled func(enabled bool)
|
||||||
openvpnLooper openvpn.Looper
|
|
||||||
unboundLooper dns.Looper
|
|
||||||
updaterLooper updater.Looper
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
|
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if h.logging {
|
r.RequestURI = strings.TrimSuffix(r.RequestURI, "/")
|
||||||
h.logger.Info("HTTP %s %s", request.Method, request.RequestURI)
|
if !strings.HasPrefix(r.RequestURI, "/v1/") && r.RequestURI != "/v1" {
|
||||||
}
|
h.v0.ServeHTTP(w, r)
|
||||||
switch request.Method {
|
return
|
||||||
case http.MethodGet:
|
|
||||||
switch request.RequestURI {
|
|
||||||
case "/version":
|
|
||||||
h.getVersion(responseWriter)
|
|
||||||
responseWriter.WriteHeader(http.StatusOK)
|
|
||||||
case "/openvpn/actions/restart":
|
|
||||||
h.openvpnLooper.Restart()
|
|
||||||
responseWriter.WriteHeader(http.StatusOK)
|
|
||||||
case "/unbound/actions/restart":
|
|
||||||
h.unboundLooper.Restart()
|
|
||||||
responseWriter.WriteHeader(http.StatusOK)
|
|
||||||
case "/openvpn/portforwarded":
|
|
||||||
h.getPortForwarded(responseWriter)
|
|
||||||
case "/openvpn/settings":
|
|
||||||
h.getOpenvpnSettings(responseWriter)
|
|
||||||
case "/updater/restart":
|
|
||||||
h.updaterLooper.Restart()
|
|
||||||
responseWriter.WriteHeader(http.StatusOK)
|
|
||||||
default:
|
|
||||||
errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI)
|
|
||||||
http.Error(responseWriter, errString, http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI)
|
|
||||||
http.Error(responseWriter, errString, http.StatusBadRequest)
|
|
||||||
}
|
}
|
||||||
|
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/v1")
|
||||||
|
h.v1.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
69
internal/server/handlerv0.go
Normal file
69
internal/server/handlerv0.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/dns"
|
||||||
|
"github.com/qdm12/gluetun/internal/openvpn"
|
||||||
|
"github.com/qdm12/gluetun/internal/updater"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHandlerV0(logger logging.Logger,
|
||||||
|
openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler {
|
||||||
|
return &handlerV0{
|
||||||
|
logger: logger,
|
||||||
|
openvpn: openvpn,
|
||||||
|
dns: dns,
|
||||||
|
updater: updater,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type handlerV0 struct {
|
||||||
|
logger logging.Logger
|
||||||
|
openvpn openvpn.Looper
|
||||||
|
dns dns.Looper
|
||||||
|
updater updater.Looper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "unversioned API: only supports GET method", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch r.RequestURI {
|
||||||
|
case "/version":
|
||||||
|
http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect)
|
||||||
|
case "/openvpn/actions/restart":
|
||||||
|
outcome, _ := h.openvpn.SetStatus(constants.Stopped)
|
||||||
|
h.logger.Info("openvpn: %s", outcome)
|
||||||
|
outcome, _ = h.openvpn.SetStatus(constants.Running)
|
||||||
|
h.logger.Info("openvpn: %s", outcome)
|
||||||
|
if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
}
|
||||||
|
case "/unbound/actions/restart":
|
||||||
|
outcome, _ := h.dns.SetStatus(constants.Stopped)
|
||||||
|
h.logger.Info("dns: %s", outcome)
|
||||||
|
outcome, _ = h.dns.SetStatus(constants.Running)
|
||||||
|
h.logger.Info("dns: %s", outcome)
|
||||||
|
if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
}
|
||||||
|
case "/openvpn/portforwarded":
|
||||||
|
http.Redirect(w, r, "/v1/openvpn/portforwarded", http.StatusPermanentRedirect)
|
||||||
|
case "/openvpn/settings":
|
||||||
|
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
|
||||||
|
case "/updater/restart":
|
||||||
|
outcome, _ := h.updater.SetStatus(constants.Stopped)
|
||||||
|
h.logger.Info("updater: %s", outcome)
|
||||||
|
outcome, _ = h.updater.SetStatus(constants.Running)
|
||||||
|
h.logger.Info("updater: %s", outcome)
|
||||||
|
if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.Error(w, "unversioned API: requested URI not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
58
internal/server/handlerv1.go
Normal file
58
internal/server/handlerv1.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHandlerV1(logger logging.Logger, buildInfo models.BuildInformation,
|
||||||
|
openvpn, dns, updater http.Handler) http.Handler {
|
||||||
|
return &handlerV1{
|
||||||
|
logger: logger,
|
||||||
|
buildInfo: buildInfo,
|
||||||
|
openvpn: openvpn,
|
||||||
|
dns: dns,
|
||||||
|
updater: updater,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type handlerV1 struct {
|
||||||
|
logger logging.Logger
|
||||||
|
buildInfo models.BuildInformation
|
||||||
|
openvpn http.Handler
|
||||||
|
dns http.Handler
|
||||||
|
updater http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.RequestURI == "/version" && r.Method == http.MethodGet:
|
||||||
|
h.getVersion(w)
|
||||||
|
case strings.HasPrefix(r.RequestURI, "/openvpn"):
|
||||||
|
h.openvpn.ServeHTTP(w, r)
|
||||||
|
case strings.HasPrefix(r.RequestURI, "/dns"):
|
||||||
|
h.dns.ServeHTTP(w, r)
|
||||||
|
case strings.HasPrefix(r.RequestURI, "/updater"):
|
||||||
|
h.updater.ServeHTTP(w, r)
|
||||||
|
default:
|
||||||
|
errString := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI)
|
||||||
|
http.Error(w, errString, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handlerV1) getVersion(w http.ResponseWriter) {
|
||||||
|
data, err := json.Marshal(h.buildInfo)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := w.Write(data); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
75
internal/server/log.go
Normal file
75
internal/server/log.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func withLogMiddleware(childHandler http.Handler, logger logging.Logger, enabled bool) *logMiddleware {
|
||||||
|
return &logMiddleware{
|
||||||
|
childHandler: childHandler,
|
||||||
|
logger: logger,
|
||||||
|
timeNow: time.Now,
|
||||||
|
enabled: enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type logMiddleware struct {
|
||||||
|
childHandler http.Handler
|
||||||
|
logger logging.Logger
|
||||||
|
timeNow func() time.Time
|
||||||
|
enabled bool
|
||||||
|
enabledMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !m.isEnabled() {
|
||||||
|
m.childHandler.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tStart := m.timeNow()
|
||||||
|
statefulWriter := &statefulResponseWriter{httpWriter: w}
|
||||||
|
m.childHandler.ServeHTTP(statefulWriter, r)
|
||||||
|
duration := m.timeNow().Sub(tStart)
|
||||||
|
m.logger.Info("%d %s %s wrote %dB to %s in %s",
|
||||||
|
statefulWriter.statusCode, r.Method, r.RequestURI, statefulWriter.length, r.RemoteAddr, duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *logMiddleware) setEnabled(enabled bool) {
|
||||||
|
m.enabledMu.Lock()
|
||||||
|
defer m.enabledMu.Unlock()
|
||||||
|
m.enabled = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *logMiddleware) isEnabled() (enabled bool) {
|
||||||
|
m.enabledMu.RLock()
|
||||||
|
defer m.enabledMu.RUnlock()
|
||||||
|
return m.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
type statefulResponseWriter struct {
|
||||||
|
httpWriter http.ResponseWriter
|
||||||
|
statusCode int
|
||||||
|
length int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *statefulResponseWriter) Write(b []byte) (n int, err error) {
|
||||||
|
n, err = w.httpWriter.Write(b)
|
||||||
|
if w.statusCode == 0 {
|
||||||
|
w.statusCode = http.StatusOK
|
||||||
|
}
|
||||||
|
w.length += n
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *statefulResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
w.statusCode = statusCode
|
||||||
|
w.httpWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *statefulResponseWriter) Header() http.Header {
|
||||||
|
return w.httpWriter.Header()
|
||||||
|
}
|
||||||
@@ -3,34 +3,110 @@ package server
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/openvpn"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *handler) getPortForwarded(w http.ResponseWriter) {
|
func newOpenvpnHandler(looper openvpn.Looper, logger logging.Logger) http.Handler {
|
||||||
port := h.openvpnLooper.GetPortForwarded()
|
return &openvpnHandler{
|
||||||
data, err := json.Marshal(struct {
|
looper: looper,
|
||||||
Port uint16 `json:"port"`
|
logger: logger,
|
||||||
}{port})
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := w.Write(data); err != nil {
|
|
||||||
h.logger.Warn(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) getOpenvpnSettings(w http.ResponseWriter) {
|
type openvpnHandler struct {
|
||||||
settings := h.openvpnLooper.GetSettings()
|
looper openvpn.Looper
|
||||||
data, err := json.Marshal(settings)
|
logger logging.Logger
|
||||||
if err != nil {
|
}
|
||||||
|
|
||||||
|
func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/openvpn")
|
||||||
|
switch r.RequestURI {
|
||||||
|
case "/status":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getStatus(w)
|
||||||
|
case http.MethodPut:
|
||||||
|
h.setStatus(w, r)
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
case "/settings":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getSettings(w)
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
case "/portforwarded":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getPortForwarded(w)
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *openvpnHandler) getStatus(w http.ResponseWriter) {
|
||||||
|
status := h.looper.GetStatus()
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
data := statusWrapper{Status: string(status)}
|
||||||
|
if err := encoder.Encode(data); err != nil {
|
||||||
h.logger.Warn(err)
|
h.logger.Warn(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := w.Write(data); err != nil {
|
}
|
||||||
|
|
||||||
|
func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) { //nolint:dupl
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
var data statusWrapper
|
||||||
|
if err := decoder.Decode(&data); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
status, err := data.getStatus()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outcome, err := h.looper.SetStatus(status)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
|
||||||
h.logger.Warn(err)
|
h.logger.Warn(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
|
||||||
|
settings := h.looper.GetSettings()
|
||||||
|
settings.User = "redacted"
|
||||||
|
settings.Password = "redacted"
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
if err := encoder.Encode(settings); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
|
||||||
|
port := h.looper.GetPortForwarded()
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
data := portWrapper{Port: port}
|
||||||
|
if err := encoder.Encode(data); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ type server struct {
|
|||||||
handler http.Handler
|
handler http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(address string, logging bool, logger logging.Logger, buildInfo models.BuildInformation,
|
func New(address string, logging bool, logger logging.Logger,
|
||||||
|
buildInfo models.BuildInformation,
|
||||||
openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server {
|
openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server {
|
||||||
serverLogger := logger.WithPrefix("http server: ")
|
serverLogger := logger.WithPrefix("http server: ")
|
||||||
handler := newHandler(serverLogger, logging, buildInfo, openvpnLooper, unboundLooper, updaterLooper)
|
handler := newHandler(serverLogger, logging, buildInfo, openvpnLooper, unboundLooper, updaterLooper)
|
||||||
|
|||||||
78
internal/server/updater.go
Normal file
78
internal/server/updater.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
//nolint:dupl
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/updater"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newUpdaterHandler(
|
||||||
|
looper updater.Looper,
|
||||||
|
logger logging.Logger) http.Handler {
|
||||||
|
return &updaterHandler{
|
||||||
|
looper: looper,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type updaterHandler struct {
|
||||||
|
looper updater.Looper
|
||||||
|
logger logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *updaterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/updater")
|
||||||
|
switch r.RequestURI {
|
||||||
|
case "/status":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getStatus(w)
|
||||||
|
case http.MethodPut:
|
||||||
|
h.setStatus(w, r)
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *updaterHandler) getStatus(w http.ResponseWriter) {
|
||||||
|
status := h.looper.GetStatus()
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
data := statusWrapper{Status: string(status)}
|
||||||
|
if err := encoder.Encode(data); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *updaterHandler) setStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
decoder := json.NewDecoder(r.Body)
|
||||||
|
var data statusWrapper
|
||||||
|
if err := decoder.Decode(&data); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
status, err := data.getStatus()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outcome, err := h.looper.SetStatus(status)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
encoder := json.NewEncoder(w)
|
||||||
|
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
|
||||||
|
h.logger.Warn(err)
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (h *handler) getVersion(w http.ResponseWriter) {
|
|
||||||
data, err := json.Marshal(h.buildInfo)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := w.Write(data); err != nil {
|
|
||||||
h.logger.Warn(err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
32
internal/server/wrappers.go
Normal file
32
internal/server/wrappers.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
type statusWrapper struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) {
|
||||||
|
status = models.LoopStatus(sw.Status)
|
||||||
|
switch status {
|
||||||
|
case constants.Stopped, constants.Running:
|
||||||
|
return status, nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"invalid status %q: possible values are: %s, %s",
|
||||||
|
sw.Status, constants.Stopped, constants.Running)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type portWrapper struct {
|
||||||
|
Port uint16 `json:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type outcomeWrapper struct {
|
||||||
|
Outcome string `json:"outcome"`
|
||||||
|
}
|
||||||
@@ -12,9 +12,9 @@ import (
|
|||||||
// OpenVPN contains settings to configure the OpenVPN client.
|
// OpenVPN contains settings to configure the OpenVPN client.
|
||||||
type OpenVPN struct {
|
type OpenVPN struct {
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
Password string `json:"-"`
|
Password string `json:"password"`
|
||||||
Verbosity int `json:"verbosity"`
|
Verbosity int `json:"verbosity"`
|
||||||
Root bool `json:"runAsRoot"`
|
Root bool `json:"run_as_root"`
|
||||||
Cipher string `json:"cipher"`
|
Cipher string `json:"cipher"`
|
||||||
Auth string `json:"auth"`
|
Auth string `json:"auth"`
|
||||||
Provider models.ProviderSettings `json:"provider"`
|
Provider models.ProviderSettings `json:"provider"`
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func Test_OpenVPN_JSON(t *testing.T) {
|
|||||||
data, err := json.Marshal(in)
|
data, err := json.Marshal(in)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
//nolint:lll
|
//nolint:lll
|
||||||
assert.Equal(t, `{"user":"","verbosity":0,"runAsRoot":true,"cipher":"","auth":"","provider":{"name":"name","serverSelection":{"networkProtocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"customPort":0,"numbers":null,"encryptionPreset":""},"extraConfig":{"encryptionPreset":"","openvpnIPv6":false},"portForwarding":{"enabled":false,"filepath":""}}}`, string(data))
|
assert.Equal(t, `{"user":"","password":"","verbosity":0,"run_as_root":true,"cipher":"","auth":"","provider":{"name":"name","server_selection":{"network_protocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"custom_port":0,"numbers":null,"encryption_preset":""},"extra_config":{"encryption_preset":"","openvpn_ipv6":false},"port_forwarding":{"enabled":false,"filepath":""}}}`, string(data))
|
||||||
var out OpenVPN
|
var out OpenVPN
|
||||||
err = json.Unmarshal(data, &out)
|
err = json.Unmarshal(data, &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,7 +23,7 @@ type Settings struct {
|
|||||||
HTTPProxy HTTPProxy
|
HTTPProxy HTTPProxy
|
||||||
ShadowSocks ShadowSocks
|
ShadowSocks ShadowSocks
|
||||||
PublicIPPeriod time.Duration
|
PublicIPPeriod time.Duration
|
||||||
UpdaterPeriod time.Duration
|
Updater Updater
|
||||||
VersionInformation bool
|
VersionInformation bool
|
||||||
ControlServer ControlServer
|
ControlServer ControlServer
|
||||||
}
|
}
|
||||||
@@ -34,10 +33,6 @@ func (s *Settings) String() string {
|
|||||||
if s.VersionInformation {
|
if s.VersionInformation {
|
||||||
versionInformation = enabled
|
versionInformation = enabled
|
||||||
}
|
}
|
||||||
updaterLine := "Updater: disabled"
|
|
||||||
if s.UpdaterPeriod > 0 {
|
|
||||||
updaterLine = fmt.Sprintf("Updater period: %s", s.UpdaterPeriod)
|
|
||||||
}
|
|
||||||
return strings.Join([]string{
|
return strings.Join([]string{
|
||||||
"Settings summary below:",
|
"Settings summary below:",
|
||||||
s.OpenVPN.String(),
|
s.OpenVPN.String(),
|
||||||
@@ -47,9 +42,9 @@ func (s *Settings) String() string {
|
|||||||
s.HTTPProxy.String(),
|
s.HTTPProxy.String(),
|
||||||
s.ShadowSocks.String(),
|
s.ShadowSocks.String(),
|
||||||
s.ControlServer.String(),
|
s.ControlServer.String(),
|
||||||
|
s.Updater.String(),
|
||||||
"Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0
|
"Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0
|
||||||
"Version information: " + versionInformation,
|
"Version information: " + versionInformation,
|
||||||
updaterLine,
|
|
||||||
"", // new line at the end
|
"", // new line at the end
|
||||||
}, "\n")
|
}, "\n")
|
||||||
}
|
}
|
||||||
@@ -93,7 +88,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return settings, err
|
return settings, err
|
||||||
}
|
}
|
||||||
settings.UpdaterPeriod, err = paramsReader.GetUpdaterPeriod()
|
settings.Updater, err = GetUpdaterSettings(paramsReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return settings, err
|
return settings, err
|
||||||
}
|
}
|
||||||
|
|||||||
59
internal/settings/updater.go
Normal file
59
internal/settings/updater.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/params"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Updater struct {
|
||||||
|
Period time.Duration `json:"period"`
|
||||||
|
DNSAddress string `json:"dns_address"`
|
||||||
|
Cyberghost bool `json:"cyberghost"`
|
||||||
|
Mullvad bool `json:"mullvad"`
|
||||||
|
Nordvpn bool `json:"nordvpn"`
|
||||||
|
PIA bool `json:"pia"`
|
||||||
|
Privado bool `json:"privado"`
|
||||||
|
Purevpn bool `json:"purevpn"`
|
||||||
|
Surfshark bool `json:"surfshark"`
|
||||||
|
Vyprvpn bool `json:"vyprvpn"`
|
||||||
|
Windscribe bool `json:"windscribe"`
|
||||||
|
// The two below should be used in CLI mode only
|
||||||
|
Stdout bool `json:"-"` // in order to update constants file (maintainer side)
|
||||||
|
CLI bool `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpdaterSettings obtains the server updater settings using the params functions.
|
||||||
|
func GetUpdaterSettings(paramsReader params.Reader) (settings Updater, err error) {
|
||||||
|
settings = Updater{
|
||||||
|
Cyberghost: true,
|
||||||
|
Mullvad: true,
|
||||||
|
Nordvpn: true,
|
||||||
|
PIA: true,
|
||||||
|
Purevpn: true,
|
||||||
|
Surfshark: true,
|
||||||
|
Vyprvpn: true,
|
||||||
|
Windscribe: true,
|
||||||
|
Stdout: false,
|
||||||
|
CLI: false,
|
||||||
|
DNSAddress: "127.0.0.1",
|
||||||
|
}
|
||||||
|
settings.Period, err = paramsReader.GetUpdaterPeriod()
|
||||||
|
if err != nil {
|
||||||
|
return settings, err
|
||||||
|
}
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Updater) String() string {
|
||||||
|
if s.Period == 0 {
|
||||||
|
return "Server updater settings: disabled"
|
||||||
|
}
|
||||||
|
settingsList := []string{
|
||||||
|
"Server updater settings:",
|
||||||
|
fmt.Sprintf("Period: %s", s.Period),
|
||||||
|
}
|
||||||
|
return strings.Join(settingsList, "\n|--")
|
||||||
|
}
|
||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/gluetun/internal/storage"
|
"github.com/qdm12/gluetun/internal/storage"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
@@ -14,60 +16,54 @@ import (
|
|||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||||
Restart()
|
GetStatus() (status models.LoopStatus)
|
||||||
Stop()
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetPeriod() (period time.Duration)
|
GetSettings() (settings settings.Updater)
|
||||||
SetPeriod(period time.Duration)
|
SetSettings(settings settings.Updater) (outcome string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
period time.Duration
|
state state
|
||||||
periodMutex sync.RWMutex
|
// Objects
|
||||||
updater Updater
|
updater Updater
|
||||||
storage storage.Storage
|
storage storage.Storage
|
||||||
setAllServers func(allServers models.AllServers)
|
setAllServers func(allServers models.AllServers)
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
restart chan struct{}
|
// Internal channels and locks
|
||||||
stop chan struct{}
|
loopLock sync.Mutex
|
||||||
updateTicker chan struct{}
|
start chan struct{}
|
||||||
timeNow func() time.Time
|
running chan models.LoopStatus
|
||||||
timeSince func(time.Time) time.Duration
|
stop chan struct{}
|
||||||
|
stopped chan struct{}
|
||||||
|
updateTicker chan struct{}
|
||||||
|
// Mock functions
|
||||||
|
timeNow func() time.Time
|
||||||
|
timeSince func(time.Time) time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(options Options, period time.Duration, currentServers models.AllServers,
|
func NewLooper(settings settings.Updater, currentServers models.AllServers,
|
||||||
storage storage.Storage, setAllServers func(allServers models.AllServers),
|
storage storage.Storage, setAllServers func(allServers models.AllServers),
|
||||||
client *http.Client, logger logging.Logger) Looper {
|
client *http.Client, logger logging.Logger) Looper {
|
||||||
loggerWithPrefix := logger.WithPrefix("updater: ")
|
loggerWithPrefix := logger.WithPrefix("updater: ")
|
||||||
return &looper{
|
return &looper{
|
||||||
period: period,
|
state: state{
|
||||||
updater: New(options, client, currentServers, loggerWithPrefix),
|
status: constants.Stopped,
|
||||||
|
settings: settings,
|
||||||
|
},
|
||||||
|
updater: New(settings, client, currentServers, loggerWithPrefix),
|
||||||
storage: storage,
|
storage: storage,
|
||||||
setAllServers: setAllServers,
|
setAllServers: setAllServers,
|
||||||
logger: loggerWithPrefix,
|
logger: loggerWithPrefix,
|
||||||
restart: make(chan struct{}),
|
start: make(chan struct{}),
|
||||||
|
running: make(chan models.LoopStatus),
|
||||||
stop: make(chan struct{}),
|
stop: make(chan struct{}),
|
||||||
|
stopped: make(chan struct{}),
|
||||||
updateTicker: make(chan struct{}),
|
updateTicker: make(chan struct{}),
|
||||||
timeNow: time.Now,
|
timeNow: time.Now,
|
||||||
timeSince: time.Since,
|
timeSince: time.Since,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
|
||||||
func (l *looper) Stop() { l.stop <- struct{}{} }
|
|
||||||
|
|
||||||
func (l *looper) GetPeriod() (period time.Duration) {
|
|
||||||
l.periodMutex.RLock()
|
|
||||||
defer l.periodMutex.RUnlock()
|
|
||||||
return l.period
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) SetPeriod(period time.Duration) {
|
|
||||||
l.periodMutex.Lock()
|
|
||||||
l.period = period
|
|
||||||
l.periodMutex.Unlock()
|
|
||||||
l.updateTicker <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
const waitTime = 5 * time.Minute
|
const waitTime = 5 * time.Minute
|
||||||
@@ -84,52 +80,71 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
crashed := false
|
||||||
select {
|
select {
|
||||||
case <-l.restart:
|
case <-l.start:
|
||||||
l.logger.Info("starting...")
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
defer l.logger.Warn("loop exited")
|
||||||
|
|
||||||
enabled := true
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
for !enabled {
|
updateCtx, updateCancel := context.WithCancel(ctx)
|
||||||
// wait for a signal to re-enable
|
defer updateCancel()
|
||||||
|
serversCh := make(chan models.AllServers)
|
||||||
|
errorCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
servers, err := l.updater.UpdateServers(updateCtx)
|
||||||
|
if err != nil {
|
||||||
|
errorCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serversCh <- servers
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !crashed {
|
||||||
|
l.running <- constants.Running
|
||||||
|
crashed = false
|
||||||
|
} else {
|
||||||
|
l.state.setStatusWithLock(constants.Running)
|
||||||
|
}
|
||||||
|
|
||||||
|
stayHere := true
|
||||||
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-l.stop:
|
|
||||||
l.logger.Info("already disabled")
|
|
||||||
case <-l.restart:
|
|
||||||
enabled = true
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
l.logger.Warn("context canceled: exiting loop")
|
||||||
|
updateCancel()
|
||||||
|
close(errorCh)
|
||||||
return
|
return
|
||||||
|
case <-l.start:
|
||||||
|
l.logger.Info("starting")
|
||||||
|
updateCancel()
|
||||||
|
stayHere = false
|
||||||
|
case <-l.stop:
|
||||||
|
l.logger.Info("stopping")
|
||||||
|
updateCancel()
|
||||||
|
<-errorCh
|
||||||
|
l.stopped <- struct{}{}
|
||||||
|
case servers := <-serversCh:
|
||||||
|
updateCancel()
|
||||||
|
close(serversCh)
|
||||||
|
l.setAllServers(servers)
|
||||||
|
if err := l.storage.FlushToFile(servers); err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
}
|
||||||
|
l.state.setStatusWithLock(constants.Completed)
|
||||||
|
l.logger.Info("Updated servers information")
|
||||||
|
case err := <-errorCh:
|
||||||
|
updateCancel()
|
||||||
|
close(serversCh)
|
||||||
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
|
l.logAndWait(ctx, err)
|
||||||
|
crashed = true
|
||||||
|
stayHere = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
close(errorCh)
|
||||||
// Enabled and has a period set
|
|
||||||
|
|
||||||
servers, err := l.updater.UpdateServers(ctx)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.logAndWait(ctx, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
l.setAllServers(servers)
|
|
||||||
if err := l.storage.FlushToFile(servers); err != nil {
|
|
||||||
l.logger.Error(err)
|
|
||||||
}
|
|
||||||
l.logger.Info("Updated servers information")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-l.restart: // triggered restart
|
|
||||||
case <-l.stop:
|
|
||||||
enabled = false
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +153,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
timer := time.NewTimer(time.Hour)
|
timer := time.NewTimer(time.Hour)
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
timerIsStopped := true
|
timerIsStopped := true
|
||||||
if period := l.GetPeriod(); period > 0 {
|
if period := l.GetSettings().Period; period > 0 {
|
||||||
timerIsStopped = false
|
timerIsStopped = false
|
||||||
timer.Reset(period)
|
timer.Reset(period)
|
||||||
}
|
}
|
||||||
@@ -152,14 +167,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
return
|
return
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
lastTick = l.timeNow()
|
lastTick = l.timeNow()
|
||||||
l.restart <- struct{}{}
|
l.start <- struct{}{}
|
||||||
timer.Reset(l.GetPeriod())
|
timer.Reset(l.GetSettings().Period)
|
||||||
case <-l.updateTicker:
|
case <-l.updateTicker:
|
||||||
if !timerIsStopped && !timer.Stop() {
|
if !timerIsStopped && !timer.Stop() {
|
||||||
<-timer.C
|
<-timer.C
|
||||||
}
|
}
|
||||||
timerIsStopped = true
|
timerIsStopped = true
|
||||||
period := l.GetPeriod()
|
period := l.GetSettings().Period
|
||||||
if period == 0 {
|
if period == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
package updater
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
Cyberghost bool
|
|
||||||
Mullvad bool
|
|
||||||
Nordvpn bool
|
|
||||||
PIA bool
|
|
||||||
Privado bool
|
|
||||||
Purevpn bool
|
|
||||||
Surfshark bool
|
|
||||||
Vyprvpn bool
|
|
||||||
Windscribe bool
|
|
||||||
Stdout bool // in order to update constants file (maintainer side)
|
|
||||||
CLI bool
|
|
||||||
DNSAddress string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOptions(dnsAddress string) Options {
|
|
||||||
return Options{
|
|
||||||
Cyberghost: true,
|
|
||||||
Mullvad: true,
|
|
||||||
Nordvpn: true,
|
|
||||||
PIA: true,
|
|
||||||
Purevpn: true,
|
|
||||||
Surfshark: true,
|
|
||||||
Vyprvpn: true,
|
|
||||||
Windscribe: true,
|
|
||||||
Stdout: false,
|
|
||||||
CLI: false,
|
|
||||||
DNSAddress: dnsAddress,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
88
internal/updater/state.go
Normal file
88
internal/updater/state.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package updater
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type state struct {
|
||||||
|
status models.LoopStatus
|
||||||
|
settings settings.Updater
|
||||||
|
statusMu sync.RWMutex
|
||||||
|
periodMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *state) setStatusWithLock(status models.LoopStatus) {
|
||||||
|
s.statusMu.Lock()
|
||||||
|
defer s.statusMu.Unlock()
|
||||||
|
s.status = status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetStatus() (status models.LoopStatus) {
|
||||||
|
l.state.statusMu.RLock()
|
||||||
|
defer l.state.statusMu.RUnlock()
|
||||||
|
return l.state.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
defer l.state.statusMu.Unlock()
|
||||||
|
existingStatus := l.state.status
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case constants.Running:
|
||||||
|
switch existingStatus {
|
||||||
|
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
|
||||||
|
return fmt.Sprintf("already %s", existingStatus), nil
|
||||||
|
}
|
||||||
|
l.loopLock.Lock()
|
||||||
|
defer l.loopLock.Unlock()
|
||||||
|
l.state.status = constants.Starting
|
||||||
|
l.state.statusMu.Unlock()
|
||||||
|
l.start <- struct{}{}
|
||||||
|
newStatus := <-l.running
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
l.state.status = newStatus
|
||||||
|
return newStatus.String(), nil
|
||||||
|
case constants.Stopped:
|
||||||
|
switch existingStatus {
|
||||||
|
case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed:
|
||||||
|
return fmt.Sprintf("already %s", existingStatus), nil
|
||||||
|
}
|
||||||
|
l.loopLock.Lock()
|
||||||
|
defer l.loopLock.Unlock()
|
||||||
|
l.state.status = constants.Stopping
|
||||||
|
l.state.statusMu.Unlock()
|
||||||
|
l.stop <- struct{}{}
|
||||||
|
<-l.stopped
|
||||||
|
l.state.statusMu.Lock()
|
||||||
|
l.state.status = status
|
||||||
|
return status.String(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("status %q can only be %q or %q",
|
||||||
|
status, constants.Running, constants.Stopped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) GetSettings() (settings settings.Updater) {
|
||||||
|
l.state.periodMu.RLock()
|
||||||
|
defer l.state.periodMu.RUnlock()
|
||||||
|
return l.state.settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) SetSettings(settings settings.Updater) (outcome string) {
|
||||||
|
l.state.periodMu.Lock()
|
||||||
|
defer l.state.periodMu.Unlock()
|
||||||
|
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
|
||||||
|
if settingsUnchanged {
|
||||||
|
return "settings left unchanged"
|
||||||
|
}
|
||||||
|
l.state.settings = settings
|
||||||
|
l.updateTicker <- struct{}{}
|
||||||
|
return "settings updated"
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
"github.com/qdm12/golibs/network"
|
"github.com/qdm12/golibs/network"
|
||||||
)
|
)
|
||||||
@@ -17,7 +18,7 @@ type Updater interface {
|
|||||||
|
|
||||||
type updater struct {
|
type updater struct {
|
||||||
// configuration
|
// configuration
|
||||||
options Options
|
options settings.Updater
|
||||||
|
|
||||||
// state
|
// state
|
||||||
servers models.AllServers
|
servers models.AllServers
|
||||||
@@ -30,11 +31,12 @@ type updater struct {
|
|||||||
client network.Client
|
client network.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(options Options, httpClient *http.Client, currentServers models.AllServers, logger logging.Logger) Updater {
|
func New(settings settings.Updater, httpClient *http.Client,
|
||||||
if len(options.DNSAddress) == 0 {
|
currentServers models.AllServers, logger logging.Logger) Updater {
|
||||||
options.DNSAddress = "1.1.1.1"
|
if len(settings.DNSAddress) == 0 {
|
||||||
|
settings.DNSAddress = "1.1.1.1"
|
||||||
}
|
}
|
||||||
resolver := newResolver(options.DNSAddress)
|
resolver := newResolver(settings.DNSAddress)
|
||||||
const clientTimeout = 10 * time.Second
|
const clientTimeout = 10 * time.Second
|
||||||
return &updater{
|
return &updater{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
@@ -42,7 +44,7 @@ func New(options Options, httpClient *http.Client, currentServers models.AllServ
|
|||||||
println: func(s string) { fmt.Println(s) },
|
println: func(s string) { fmt.Println(s) },
|
||||||
lookupIP: newLookupIP(resolver),
|
lookupIP: newLookupIP(resolver),
|
||||||
client: network.NewClient(clientTimeout),
|
client: network.NewClient(clientTimeout),
|
||||||
options: options,
|
options: settings,
|
||||||
servers: currentServers,
|
servers: currentServers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user