Public IP getter loop refactored

This commit is contained in:
Quentin McGaw
2020-12-28 01:51:55 +00:00
parent 91f5338db0
commit db886163c2
11 changed files with 279 additions and 135 deletions

View File

@@ -47,7 +47,7 @@ ENV VPNSP=pia \
TZ= \ TZ= \
UID=1000 \ UID=1000 \
GID=1000 \ GID=1000 \
IP_STATUS_FILE="/tmp/gluetun/ip" \ PUBLICIP_FILE="/tmp/gluetun/ip" \
# PIA, Windscribe, Surfshark, Cyberghost, Vyprvpn, NordVPN, PureVPN only # PIA, Windscribe, Surfshark, Cyberghost, Vyprvpn, NordVPN, PureVPN only
USER= \ USER= \
PASSWORD= \ PASSWORD= \

View File

@@ -97,7 +97,7 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo
| Variable | Default | Choices | Description | | Variable | Default | Choices | Description |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| 🏁 `VPNSP` | `private internet access` | `private internet access`, `mullvad`, `windscribe`, `surfshark`, `vyprvpn`, `nordvpn`, `purevpn`, `privado` | VPN Service Provider | | 🏁 `VPNSP` | `private internet access` | `private internet access`, `mullvad`, `windscribe`, `surfshark`, `vyprvpn`, `nordvpn`, `purevpn`, `privado` | VPN Service Provider |
| `IP_STATUS_FILE` | `/tmp/gluetun/ip` | Any filepath | Filepath to store the public IP address assigned | | `PUBLICIP_FILE` | `/tmp/gluetun/ip` | Any filepath | Filepath to store the public IP address assigned |
| `PROTOCOL` | `udp` | `udp` or `tcp` | Network protocol to use | | `PROTOCOL` | `udp` | `udp` or `tcp` | Network protocol to use |
| `OPENVPN_VERBOSITY` | `1` | `0` to `6` | Openvpn verbosity level | | `OPENVPN_VERBOSITY` | `1` | `0` to `6` | Openvpn verbosity level |
| `OPENVPN_ROOT` | `no` | `yes` or `no` | Run OpenVPN as root | | `OPENVPN_ROOT` | `no` | `yes` or `no` | Run OpenVPN as root |

View File

@@ -235,13 +235,12 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg, signalDNSReady) go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, publicIPLooper := publicip.NewLooper(
allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) client, logger, fileManager, allSettings.PublicIP, uid, gid)
wg.Add(1) wg.Add(1)
go publicIPLooper.Run(ctx, wg) go publicIPLooper.Run(ctx, wg)
wg.Add(1) wg.Add(1)
go publicIPLooper.RunRestartTicker(ctx, wg) go publicIPLooper.RunRestartTicker(ctx, wg)
publicIPLooper.SetPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy) httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy)
wg.Add(1) wg.Add(1)
@@ -294,11 +293,6 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
case <-ctx.Done(): case <-ctx.Done():
logger.Warn("context canceled, shutting down") logger.Warn("context canceled, shutting down")
} }
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
logger.Error(err)
shutdownErrorsCount++
}
if allSettings.OpenVPN.Provider.PortForwarding.Enabled { if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
@@ -425,7 +419,8 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
startPortForward(vpnGateway) startPortForward(vpnGateway)
} }
case <-dnsReadyCh: case <-dnsReadyCh:
publicIPLooper.Restart() // TODO do not restart if disabled // Runs the Public IP getter job once
_, _ = publicIPLooper.SetStatus(constants.Running)
if !versionInformation { if !versionInformation {
break break
} }

View File

@@ -37,7 +37,7 @@ type Reader interface {
GetUID() (uid int, err error) GetUID() (uid int, err error)
GetGID() (gid int, err error) GetGID() (gid int, err error)
GetTimezone() (timezone string, err error) GetTimezone() (timezone string, err error)
GetIPStatusFilepath() (filepath models.Filepath, err error) GetPublicIPFilepath() (filepath models.Filepath, err error)
// Firewall getters // Firewall getters
GetFirewall() (enabled bool, err error) GetFirewall() (enabled bool, err error)

View File

@@ -3,6 +3,7 @@ package params
import ( import (
"time" "time"
"github.com/qdm12/gluetun/internal/models"
libparams "github.com/qdm12/golibs/params" libparams "github.com/qdm12/golibs/params"
) )
@@ -15,3 +16,13 @@ func (r *reader) GetPublicIPPeriod() (period time.Duration, err error) {
} }
return time.ParseDuration(s) return time.ParseDuration(s)
} }
// GetPublicIPFilepath obtains the public IP filepath
// from the environment variable PUBLICIP_FILE with retro-compatible
// environment variable IP_STATUS_FILE.
func (r *reader) GetPublicIPFilepath() (filepath models.Filepath, err error) {
filepathStr, err := r.envParams.GetPath("PUBLICIP_FILE",
libparams.RetroKeys([]string{"IP_STATUS_FILE"}, r.onRetroActive),
libparams.Default("/tmp/gluetun/ip"), libparams.CaseSensitiveValue())
return models.Filepath(filepathStr), err
}

View File

@@ -1,7 +1,6 @@
package params package params
import ( import (
"github.com/qdm12/gluetun/internal/models"
libparams "github.com/qdm12/golibs/params" libparams "github.com/qdm12/golibs/params"
) )
@@ -19,11 +18,3 @@ func (r *reader) GetGID() (gid int, err error) {
func (r *reader) GetTimezone() (timezone string, err error) { func (r *reader) GetTimezone() (timezone string, err error) {
return r.envParams.GetEnv("TZ") return r.envParams.GetEnv("TZ")
} }
// GetIPStatusFilepath obtains the IP status file path
// from the environment variable IP_STATUS_FILE.
func (r *reader) GetIPStatusFilepath() (filepath models.Filepath, err error) {
filepathStr, err := r.envParams.GetPath("IP_STATUS_FILE",
libparams.Default("/tmp/gluetun/ip"), libparams.CaseSensitiveValue())
return models.Filepath(filepathStr), err
}

View File

@@ -6,7 +6,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
@@ -15,65 +17,57 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() GetStatus() (status models.LoopStatus)
Stop() SetStatus(status models.LoopStatus) (outcome string, err error)
GetPeriod() (period time.Duration) GetSettings() (settings settings.PublicIP)
SetPeriod(period time.Duration) SetSettings(settings settings.PublicIP) (outcome string)
GetPublicIP() (publicIP net.IP) GetPublicIP() (publicIP net.IP)
} }
type looper struct { type looper struct {
period time.Duration state state
periodMutex sync.RWMutex // Objects
getter IPGetter getter IPGetter
logger logging.Logger logger logging.Logger
fileManager files.FileManager fileManager files.FileManager
ipMutex sync.RWMutex // Fixed settings
ip net.IP uid int
ipStatusFilepath models.Filepath gid int
uid int // Internal channels and locks
gid int loopLock sync.Mutex
restart chan struct{} start chan struct{}
stop chan struct{} running chan models.LoopStatus
updateTicker chan struct{} stop chan struct{}
timeNow func() time.Time stopped chan struct{}
timeSince func(time.Time) time.Duration updateTicker chan struct{}
// Mock functions
timeNow func() time.Time
timeSince func(time.Time) time.Duration
} }
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
ipStatusFilepath models.Filepath, period time.Duration, uid, gid int) Looper { settings settings.PublicIP, uid, gid int) Looper {
return &looper{ return &looper{
period: period, state: state{
getter: NewIPGetter(client), status: constants.Stopped,
logger: logger.WithPrefix("ip getter: "), settings: settings,
fileManager: fileManager, },
ipStatusFilepath: ipStatusFilepath, // Objects
uid: uid, getter: NewIPGetter(client),
gid: gid, logger: logger.WithPrefix("ip getter: "),
restart: make(chan struct{}), fileManager: fileManager,
stop: make(chan struct{}), uid: uid,
updateTicker: make(chan struct{}), gid: gid,
timeNow: time.Now, start: make(chan struct{}),
timeSince: time.Since, running: make(chan models.LoopStatus),
stop: make(chan struct{}),
stopped: make(chan struct{}),
updateTicker: make(chan struct{}),
timeNow: time.Now,
timeSince: time.Since,
} }
} }
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Stop() { l.stop <- struct{}{} }
func (l *looper) GetPeriod() (period time.Duration) {
l.periodMutex.RLock()
defer l.periodMutex.RUnlock()
return l.period
}
func (l *looper) SetPeriod(period time.Duration) {
l.periodMutex.Lock()
l.period = period
l.periodMutex.Unlock()
l.updateTicker <- struct{}{}
}
func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err) l.logger.Error(err)
const waitTime = 5 * time.Second const waitTime = 5 * time.Second
@@ -90,54 +84,84 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
crashed := false
select { select {
case <-l.restart: case <-l.start:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited") defer l.logger.Warn("loop exited")
enabled := true
for ctx.Err() == nil { for ctx.Err() == nil {
for !enabled { getCtx, getCancel := context.WithCancel(ctx)
// wait for a signal to re-enable defer getCancel()
select {
case <-l.stop: ipCh := make(chan net.IP)
l.logger.Info("already disabled") errorCh := make(chan error)
case <-l.restart: go func() {
enabled = true ip, err := l.getter.Get(getCtx)
case <-ctx.Done(): if err != nil {
errorCh <- err
return return
} }
ipCh <- ip
}()
if !crashed {
l.running <- constants.Running
crashed = false
} else {
l.state.setStatusWithLock(constants.Running)
} }
// Enabled and has a period set stayHere := true
for stayHere {
ip, err := l.getter.Get(ctx) select {
if err != nil { case <-ctx.Done():
l.logAndWait(ctx, err) l.logger.Warn("context canceled: exiting loop")
continue getCancel()
} close(errorCh)
l.setPublicIP(ip) filepath := l.GetSettings().IPFilepath
l.logger.Info("Public IP address is %s", ip) l.logger.Info("Removing ip file %s", filepath)
const userReadWritePermissions = 0600 if err := l.fileManager.Remove(string(filepath)); err != nil {
err = l.fileManager.WriteLinesToFile( l.logger.Error(err)
string(l.ipStatusFilepath), }
[]string{ip.String()}, return
files.Ownership(l.uid, l.gid), case <-l.start:
files.Permissions(userReadWritePermissions)) l.logger.Info("starting")
if err != nil { getCancel()
l.logAndWait(ctx, err) stayHere = false
continue case <-l.stop:
} l.logger.Info("stopping")
select { getCancel()
case <-l.restart: // triggered restart <-errorCh
case <-l.stop: l.stopped <- struct{}{}
enabled = false case ip := <-ipCh:
case <-ctx.Done(): getCancel()
return l.state.setPublicIP(ip)
l.logger.Info("Public IP address is %s", ip)
const userReadWritePermissions = 0600
err := l.fileManager.WriteLinesToFile(
string(l.state.settings.IPFilepath),
[]string{ip.String()},
files.Ownership(l.uid, l.gid),
files.Permissions(userReadWritePermissions))
if err != nil {
l.logger.Error(err)
}
l.state.setStatusWithLock(constants.Completed)
case err := <-errorCh:
getCancel()
close(ipCh)
l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err)
crashed = true
stayHere = false
}
} }
close(errorCh)
} }
} }
@@ -146,10 +170,9 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
timer := time.NewTimer(time.Hour) timer := time.NewTimer(time.Hour)
timer.Stop() // 1 hour, cannot be a race condition timer.Stop() // 1 hour, cannot be a race condition
timerIsStopped := true timerIsStopped := true
period := l.GetPeriod() if period := l.GetSettings().Period; period > 0 {
if period > 0 {
timer.Reset(period)
timerIsStopped = false timerIsStopped = false
timer.Reset(period)
} }
lastTick := time.Unix(0, 0) lastTick := time.Unix(0, 0)
for { for {
@@ -161,14 +184,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
return return
case <-timer.C: case <-timer.C:
lastTick = l.timeNow() lastTick = l.timeNow()
l.restart <- struct{}{} l.start <- struct{}{}
timer.Reset(l.GetPeriod()) timer.Reset(l.GetSettings().Period)
case <-l.updateTicker: case <-l.updateTicker:
if !timer.Stop() { if !timerIsStopped && !timer.Stop() {
<-timer.C <-timer.C
} }
timerIsStopped = true timerIsStopped = true
period := l.GetPeriod() period := l.GetSettings().Period
if period == 0 { if period == 0 {
continue continue
} }

View File

@@ -1,17 +1,110 @@
package publicip package publicip
import "net" import (
"fmt"
"net"
"reflect"
"sync"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
)
type state struct {
status models.LoopStatus
settings settings.PublicIP
ip net.IP
statusMu sync.RWMutex
settingsMu sync.RWMutex
ipMu sync.RWMutex
}
func (s *state) setStatusWithLock(status models.LoopStatus) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.status = status
}
func (l *looper) GetStatus() (status models.LoopStatus) {
l.state.statusMu.RLock()
defer l.state.statusMu.RUnlock()
return l.state.status
}
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
l.state.statusMu.Lock()
defer l.state.statusMu.Unlock()
existingStatus := l.state.status
switch status {
case constants.Running:
switch existingStatus {
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Starting
l.state.statusMu.Unlock()
l.start <- struct{}{}
newStatus := <-l.running
l.state.statusMu.Lock()
l.state.status = newStatus
return newStatus.String(), nil
case constants.Stopped:
switch existingStatus {
case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Stopping
l.state.statusMu.Unlock()
l.stop <- struct{}{}
<-l.stopped
l.state.statusMu.Lock()
l.state.status = status
return status.String(), nil
default:
return "", fmt.Errorf("status %q can only be %q or %q",
status, constants.Running, constants.Stopped)
}
}
func (l *looper) GetSettings() (settings settings.PublicIP) {
l.state.settingsMu.RLock()
defer l.state.settingsMu.RUnlock()
return l.state.settings
}
func (l *looper) SetSettings(settings settings.PublicIP) (outcome string) {
l.state.settingsMu.Lock()
defer l.state.settingsMu.Unlock()
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
if settingsUnchanged {
return "settings left unchanged"
}
periodChanged := l.state.settings.Period != settings.Period
l.state.settings = settings
if periodChanged {
l.updateTicker <- struct{}{}
// TODO blocking
}
return "settings updated"
}
func (l *looper) GetPublicIP() (publicIP net.IP) { func (l *looper) GetPublicIP() (publicIP net.IP) {
l.ipMutex.RLock() l.state.ipMu.RLock()
defer l.ipMutex.RUnlock() defer l.state.ipMu.RUnlock()
publicIP = make(net.IP, len(l.ip)) publicIP = make(net.IP, len(l.state.ip))
copy(publicIP, l.ip) copy(publicIP, l.state.ip)
return publicIP return publicIP
} }
func (l *looper) setPublicIP(publicIP net.IP) { func (s *state) setPublicIP(publicIP net.IP) {
l.ipMutex.Lock() s.ipMu.Lock()
defer l.ipMutex.Unlock() defer s.ipMu.Unlock()
l.ip = publicIP s.ip = make(net.IP, len(publicIP))
copy(s.ip, publicIP)
} }

View File

@@ -0,0 +1,39 @@
package settings
import (
"fmt"
"strings"
"time"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/params"
)
type PublicIP struct {
Period time.Duration `json:"period"`
IPFilepath models.Filepath `json:"ip_filepath"`
}
func getPublicIPSettings(paramsReader params.Reader) (settings PublicIP, err error) {
settings.Period, err = paramsReader.GetPublicIPPeriod()
if err != nil {
return settings, err
}
settings.IPFilepath, err = paramsReader.GetPublicIPFilepath()
if err != nil {
return settings, err
}
return settings, nil
}
func (s *PublicIP) String() string {
if s.Period == 0 {
return "Public IP getter settings: disabled"
}
settingsList := []string{
"Public IP getter settings:",
fmt.Sprintf("Period: %s", s.Period),
fmt.Sprintf("IP file: %s", s.IPFilepath),
}
return strings.Join(settingsList, "\n|--")
}

View File

@@ -2,7 +2,6 @@ package settings
import ( import (
"strings" "strings"
"time"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/params"
@@ -22,8 +21,8 @@ type Settings struct {
Firewall Firewall Firewall Firewall
HTTPProxy HTTPProxy HTTPProxy HTTPProxy
ShadowSocks ShadowSocks ShadowSocks ShadowSocks
PublicIPPeriod time.Duration
Updater Updater Updater Updater
PublicIP PublicIP
VersionInformation bool VersionInformation bool
ControlServer ControlServer ControlServer ControlServer
} }
@@ -43,7 +42,7 @@ func (s *Settings) String() string {
s.ShadowSocks.String(), s.ShadowSocks.String(),
s.ControlServer.String(), s.ControlServer.String(),
s.Updater.String(), s.Updater.String(),
"Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0 s.PublicIP.String(),
"Version information: " + versionInformation, "Version information: " + versionInformation,
"", // new line at the end "", // new line at the end
}, "\n") }, "\n")
@@ -80,7 +79,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) {
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.PublicIPPeriod, err = paramsReader.GetPublicIPPeriod() settings.PublicIP, err = getPublicIPSettings(paramsReader)
if err != nil { if err != nil {
return settings, err return settings, err
} }

View File

@@ -4,16 +4,14 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/params"
) )
// System contains settings to configure system related elements. // System contains settings to configure system related elements.
type System struct { type System struct {
UID int UID int
GID int GID int
Timezone string Timezone string
IPStatusFilepath models.Filepath
} }
// GetSystemSettings obtains the System settings using the params functions. // GetSystemSettings obtains the System settings using the params functions.
@@ -30,10 +28,6 @@ func GetSystemSettings(paramsReader params.Reader) (settings System, err error)
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.IPStatusFilepath, err = paramsReader.GetIPStatusFilepath()
if err != nil {
return settings, err
}
return settings, nil return settings, nil
} }
@@ -43,7 +37,6 @@ func (s *System) String() string {
fmt.Sprintf("User ID: %d", s.UID), fmt.Sprintf("User ID: %d", s.UID),
fmt.Sprintf("Group ID: %d", s.GID), fmt.Sprintf("Group ID: %d", s.GID),
fmt.Sprintf("Timezone: %s", s.Timezone), fmt.Sprintf("Timezone: %s", s.Timezone),
fmt.Sprintf("IP Status filepath: %s", s.IPStatusFilepath),
} }
return strings.Join(settingsList, "\n|--") return strings.Join(settingsList, "\n|--")
} }