diff --git a/Dockerfile b/Dockerfile index 4f14bd35..49a0075f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,7 +47,7 @@ ENV VPNSP=pia \ TZ= \ UID=1000 \ GID=1000 \ - IP_STATUS_FILE="/tmp/gluetun/ip" \ + PUBLICIP_FILE="/tmp/gluetun/ip" \ # PIA, Windscribe, Surfshark, Cyberghost, Vyprvpn, NordVPN, PureVPN only USER= \ PASSWORD= \ diff --git a/README.md b/README.md index 9627dc5c..3a2e28f3 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo | Variable | Default | Choices | Description | | --- | --- | --- | --- | | 🏁 `VPNSP` | `private internet access` | `private internet access`, `mullvad`, `windscribe`, `surfshark`, `vyprvpn`, `nordvpn`, `purevpn`, `privado` | VPN Service Provider | -| `IP_STATUS_FILE` | `/tmp/gluetun/ip` | Any filepath | Filepath to store the public IP address assigned | +| `PUBLICIP_FILE` | `/tmp/gluetun/ip` | Any filepath | Filepath to store the public IP address assigned | | `PROTOCOL` | `udp` | `udp` or `tcp` | Network protocol to use | | `OPENVPN_VERBOSITY` | `1` | `0` to `6` | Openvpn verbosity level | | `OPENVPN_ROOT` | `no` | `yes` or `no` | Run OpenVPN as root | diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 86187838..dc5a4f5a 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -235,13 +235,12 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker go unboundLooper.Run(ctx, wg, signalDNSReady) - publicIPLooper := publicip.NewLooper(client, logger, fileManager, - allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) + publicIPLooper := publicip.NewLooper( + client, logger, fileManager, allSettings.PublicIP, uid, gid) wg.Add(1) go publicIPLooper.Run(ctx, wg) wg.Add(1) go publicIPLooper.RunRestartTicker(ctx, wg) - publicIPLooper.SetPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy) wg.Add(1) @@ -294,11 +293,6 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go case <-ctx.Done(): logger.Warn("context canceled, shutting down") } - logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath) - if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil { - logger.Error(err) - shutdownErrorsCount++ - } if allSettings.OpenVPN.Provider.PortForwarding.Enabled { logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { @@ -425,7 +419,8 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn startPortForward(vpnGateway) } case <-dnsReadyCh: - publicIPLooper.Restart() // TODO do not restart if disabled + // Runs the Public IP getter job once + _, _ = publicIPLooper.SetStatus(constants.Running) if !versionInformation { break } diff --git a/internal/params/params.go b/internal/params/params.go index b6c31397..c1b40709 100644 --- a/internal/params/params.go +++ b/internal/params/params.go @@ -37,7 +37,7 @@ type Reader interface { GetUID() (uid int, err error) GetGID() (gid int, err error) GetTimezone() (timezone string, err error) - GetIPStatusFilepath() (filepath models.Filepath, err error) + GetPublicIPFilepath() (filepath models.Filepath, err error) // Firewall getters GetFirewall() (enabled bool, err error) diff --git a/internal/params/publicip.go b/internal/params/publicip.go index 04845fa4..369008ab 100644 --- a/internal/params/publicip.go +++ b/internal/params/publicip.go @@ -3,6 +3,7 @@ package params import ( "time" + "github.com/qdm12/gluetun/internal/models" libparams "github.com/qdm12/golibs/params" ) @@ -15,3 +16,13 @@ func (r *reader) GetPublicIPPeriod() (period time.Duration, err error) { } return time.ParseDuration(s) } + +// GetPublicIPFilepath obtains the public IP filepath +// from the environment variable PUBLICIP_FILE with retro-compatible +// environment variable IP_STATUS_FILE. +func (r *reader) GetPublicIPFilepath() (filepath models.Filepath, err error) { + filepathStr, err := r.envParams.GetPath("PUBLICIP_FILE", + libparams.RetroKeys([]string{"IP_STATUS_FILE"}, r.onRetroActive), + libparams.Default("/tmp/gluetun/ip"), libparams.CaseSensitiveValue()) + return models.Filepath(filepathStr), err +} diff --git a/internal/params/system.go b/internal/params/system.go index ac82b928..6f5b60aa 100644 --- a/internal/params/system.go +++ b/internal/params/system.go @@ -1,7 +1,6 @@ package params import ( - "github.com/qdm12/gluetun/internal/models" libparams "github.com/qdm12/golibs/params" ) @@ -19,11 +18,3 @@ func (r *reader) GetGID() (gid int, err error) { func (r *reader) GetTimezone() (timezone string, err error) { return r.envParams.GetEnv("TZ") } - -// GetIPStatusFilepath obtains the IP status file path -// from the environment variable IP_STATUS_FILE. -func (r *reader) GetIPStatusFilepath() (filepath models.Filepath, err error) { - filepathStr, err := r.envParams.GetPath("IP_STATUS_FILE", - libparams.Default("/tmp/gluetun/ip"), libparams.CaseSensitiveValue()) - return models.Filepath(filepathStr), err -} diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index f4ef4bb6..c9eb92b3 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -6,7 +6,9 @@ import ( "sync" "time" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" @@ -15,65 +17,57 @@ import ( type Looper interface { Run(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) - Restart() - Stop() - GetPeriod() (period time.Duration) - SetPeriod(period time.Duration) + GetStatus() (status models.LoopStatus) + SetStatus(status models.LoopStatus) (outcome string, err error) + GetSettings() (settings settings.PublicIP) + SetSettings(settings settings.PublicIP) (outcome string) GetPublicIP() (publicIP net.IP) } type looper struct { - period time.Duration - periodMutex sync.RWMutex - getter IPGetter - logger logging.Logger - fileManager files.FileManager - ipMutex sync.RWMutex - ip net.IP - ipStatusFilepath models.Filepath - uid int - gid int - restart chan struct{} - stop chan struct{} - updateTicker chan struct{} - timeNow func() time.Time - timeSince func(time.Time) time.Duration + state state + // Objects + getter IPGetter + logger logging.Logger + fileManager files.FileManager + // Fixed settings + uid int + gid int + // Internal channels and locks + loopLock sync.Mutex + start chan struct{} + running chan models.LoopStatus + stop chan struct{} + stopped chan struct{} + updateTicker chan struct{} + // Mock functions + timeNow func() time.Time + timeSince func(time.Time) time.Duration } func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, - ipStatusFilepath models.Filepath, period time.Duration, uid, gid int) Looper { + settings settings.PublicIP, uid, gid int) Looper { return &looper{ - period: period, - getter: NewIPGetter(client), - logger: logger.WithPrefix("ip getter: "), - fileManager: fileManager, - ipStatusFilepath: ipStatusFilepath, - uid: uid, - gid: gid, - restart: make(chan struct{}), - stop: make(chan struct{}), - updateTicker: make(chan struct{}), - timeNow: time.Now, - timeSince: time.Since, + state: state{ + status: constants.Stopped, + settings: settings, + }, + // Objects + getter: NewIPGetter(client), + logger: logger.WithPrefix("ip getter: "), + fileManager: fileManager, + uid: uid, + gid: gid, + start: make(chan struct{}), + running: make(chan models.LoopStatus), + stop: make(chan struct{}), + stopped: make(chan struct{}), + updateTicker: make(chan struct{}), + timeNow: time.Now, + timeSince: time.Since, } } -func (l *looper) Restart() { l.restart <- struct{}{} } -func (l *looper) Stop() { l.stop <- struct{}{} } - -func (l *looper) GetPeriod() (period time.Duration) { - l.periodMutex.RLock() - defer l.periodMutex.RUnlock() - return l.period -} - -func (l *looper) SetPeriod(period time.Duration) { - l.periodMutex.Lock() - l.period = period - l.periodMutex.Unlock() - l.updateTicker <- struct{}{} -} - func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Error(err) const waitTime = 5 * time.Second @@ -90,54 +84,84 @@ func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() + + crashed := false + select { - case <-l.restart: + case <-l.start: case <-ctx.Done(): return } defer l.logger.Warn("loop exited") - enabled := true - for ctx.Err() == nil { - for !enabled { - // wait for a signal to re-enable - select { - case <-l.stop: - l.logger.Info("already disabled") - case <-l.restart: - enabled = true - case <-ctx.Done(): + getCtx, getCancel := context.WithCancel(ctx) + defer getCancel() + + ipCh := make(chan net.IP) + errorCh := make(chan error) + go func() { + ip, err := l.getter.Get(getCtx) + if err != nil { + errorCh <- err return } + ipCh <- ip + }() + + if !crashed { + l.running <- constants.Running + crashed = false + } else { + l.state.setStatusWithLock(constants.Running) } - // Enabled and has a period set - - ip, err := l.getter.Get(ctx) - if err != nil { - l.logAndWait(ctx, err) - continue - } - l.setPublicIP(ip) - l.logger.Info("Public IP address is %s", ip) - const userReadWritePermissions = 0600 - err = l.fileManager.WriteLinesToFile( - string(l.ipStatusFilepath), - []string{ip.String()}, - files.Ownership(l.uid, l.gid), - files.Permissions(userReadWritePermissions)) - if err != nil { - l.logAndWait(ctx, err) - continue - } - select { - case <-l.restart: // triggered restart - case <-l.stop: - enabled = false - case <-ctx.Done(): - return + stayHere := true + for stayHere { + select { + case <-ctx.Done(): + l.logger.Warn("context canceled: exiting loop") + getCancel() + close(errorCh) + filepath := l.GetSettings().IPFilepath + l.logger.Info("Removing ip file %s", filepath) + if err := l.fileManager.Remove(string(filepath)); err != nil { + l.logger.Error(err) + } + return + case <-l.start: + l.logger.Info("starting") + getCancel() + stayHere = false + case <-l.stop: + l.logger.Info("stopping") + getCancel() + <-errorCh + l.stopped <- struct{}{} + case ip := <-ipCh: + getCancel() + l.state.setPublicIP(ip) + l.logger.Info("Public IP address is %s", ip) + const userReadWritePermissions = 0600 + err := l.fileManager.WriteLinesToFile( + string(l.state.settings.IPFilepath), + []string{ip.String()}, + files.Ownership(l.uid, l.gid), + files.Permissions(userReadWritePermissions)) + if err != nil { + l.logger.Error(err) + } + l.state.setStatusWithLock(constants.Completed) + case err := <-errorCh: + getCancel() + close(ipCh) + l.state.setStatusWithLock(constants.Crashed) + l.logAndWait(ctx, err) + crashed = true + stayHere = false + } } + close(errorCh) } } @@ -146,10 +170,9 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { timer := time.NewTimer(time.Hour) timer.Stop() // 1 hour, cannot be a race condition timerIsStopped := true - period := l.GetPeriod() - if period > 0 { - timer.Reset(period) + if period := l.GetSettings().Period; period > 0 { timerIsStopped = false + timer.Reset(period) } lastTick := time.Unix(0, 0) for { @@ -161,14 +184,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { return case <-timer.C: lastTick = l.timeNow() - l.restart <- struct{}{} - timer.Reset(l.GetPeriod()) + l.start <- struct{}{} + timer.Reset(l.GetSettings().Period) case <-l.updateTicker: - if !timer.Stop() { + if !timerIsStopped && !timer.Stop() { <-timer.C } timerIsStopped = true - period := l.GetPeriod() + period := l.GetSettings().Period if period == 0 { continue } diff --git a/internal/publicip/state.go b/internal/publicip/state.go index 032e571a..6f6a3cfe 100644 --- a/internal/publicip/state.go +++ b/internal/publicip/state.go @@ -1,17 +1,110 @@ package publicip -import "net" +import ( + "fmt" + "net" + "reflect" + "sync" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/settings" +) + +type state struct { + status models.LoopStatus + settings settings.PublicIP + ip net.IP + statusMu sync.RWMutex + settingsMu sync.RWMutex + ipMu sync.RWMutex +} + +func (s *state) setStatusWithLock(status models.LoopStatus) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.status = status +} + +func (l *looper) GetStatus() (status models.LoopStatus) { + l.state.statusMu.RLock() + defer l.state.statusMu.RUnlock() + return l.state.status +} + +func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) { + l.state.statusMu.Lock() + defer l.state.statusMu.Unlock() + existingStatus := l.state.status + + switch status { + case constants.Running: + switch existingStatus { + case constants.Starting, constants.Running, constants.Stopping, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Starting + l.state.statusMu.Unlock() + l.start <- struct{}{} + newStatus := <-l.running + l.state.statusMu.Lock() + l.state.status = newStatus + return newStatus.String(), nil + case constants.Stopped: + switch existingStatus { + case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed: + return fmt.Sprintf("already %s", existingStatus), nil + } + l.loopLock.Lock() + defer l.loopLock.Unlock() + l.state.status = constants.Stopping + l.state.statusMu.Unlock() + l.stop <- struct{}{} + <-l.stopped + l.state.statusMu.Lock() + l.state.status = status + return status.String(), nil + default: + return "", fmt.Errorf("status %q can only be %q or %q", + status, constants.Running, constants.Stopped) + } +} + +func (l *looper) GetSettings() (settings settings.PublicIP) { + l.state.settingsMu.RLock() + defer l.state.settingsMu.RUnlock() + return l.state.settings +} + +func (l *looper) SetSettings(settings settings.PublicIP) (outcome string) { + l.state.settingsMu.Lock() + defer l.state.settingsMu.Unlock() + settingsUnchanged := reflect.DeepEqual(settings, l.state.settings) + if settingsUnchanged { + return "settings left unchanged" + } + periodChanged := l.state.settings.Period != settings.Period + l.state.settings = settings + if periodChanged { + l.updateTicker <- struct{}{} + // TODO blocking + } + return "settings updated" +} func (l *looper) GetPublicIP() (publicIP net.IP) { - l.ipMutex.RLock() - defer l.ipMutex.RUnlock() - publicIP = make(net.IP, len(l.ip)) - copy(publicIP, l.ip) + l.state.ipMu.RLock() + defer l.state.ipMu.RUnlock() + publicIP = make(net.IP, len(l.state.ip)) + copy(publicIP, l.state.ip) return publicIP } -func (l *looper) setPublicIP(publicIP net.IP) { - l.ipMutex.Lock() - defer l.ipMutex.Unlock() - l.ip = publicIP +func (s *state) setPublicIP(publicIP net.IP) { + s.ipMu.Lock() + defer s.ipMu.Unlock() + s.ip = make(net.IP, len(publicIP)) + copy(s.ip, publicIP) } diff --git a/internal/settings/publicip.go b/internal/settings/publicip.go new file mode 100644 index 00000000..93e8bbef --- /dev/null +++ b/internal/settings/publicip.go @@ -0,0 +1,39 @@ +package settings + +import ( + "fmt" + "strings" + "time" + + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/params" +) + +type PublicIP struct { + Period time.Duration `json:"period"` + IPFilepath models.Filepath `json:"ip_filepath"` +} + +func getPublicIPSettings(paramsReader params.Reader) (settings PublicIP, err error) { + settings.Period, err = paramsReader.GetPublicIPPeriod() + if err != nil { + return settings, err + } + settings.IPFilepath, err = paramsReader.GetPublicIPFilepath() + if err != nil { + return settings, err + } + return settings, nil +} + +func (s *PublicIP) String() string { + if s.Period == 0 { + return "Public IP getter settings: disabled" + } + settingsList := []string{ + "Public IP getter settings:", + fmt.Sprintf("Period: %s", s.Period), + fmt.Sprintf("IP file: %s", s.IPFilepath), + } + return strings.Join(settingsList, "\n|--") +} diff --git a/internal/settings/settings.go b/internal/settings/settings.go index 97122428..a299f9d0 100644 --- a/internal/settings/settings.go +++ b/internal/settings/settings.go @@ -2,7 +2,6 @@ package settings import ( "strings" - "time" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/params" @@ -22,8 +21,8 @@ type Settings struct { Firewall Firewall HTTPProxy HTTPProxy ShadowSocks ShadowSocks - PublicIPPeriod time.Duration Updater Updater + PublicIP PublicIP VersionInformation bool ControlServer ControlServer } @@ -43,7 +42,7 @@ func (s *Settings) String() string { s.ShadowSocks.String(), s.ControlServer.String(), s.Updater.String(), - "Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0 + s.PublicIP.String(), "Version information: " + versionInformation, "", // new line at the end }, "\n") @@ -80,7 +79,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) { if err != nil { return settings, err } - settings.PublicIPPeriod, err = paramsReader.GetPublicIPPeriod() + settings.PublicIP, err = getPublicIPSettings(paramsReader) if err != nil { return settings, err } diff --git a/internal/settings/system.go b/internal/settings/system.go index 7bae2014..3d51b16c 100644 --- a/internal/settings/system.go +++ b/internal/settings/system.go @@ -4,16 +4,14 @@ import ( "fmt" "strings" - "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/params" ) // System contains settings to configure system related elements. type System struct { - UID int - GID int - Timezone string - IPStatusFilepath models.Filepath + UID int + GID int + Timezone string } // GetSystemSettings obtains the System settings using the params functions. @@ -30,10 +28,6 @@ func GetSystemSettings(paramsReader params.Reader) (settings System, err error) if err != nil { return settings, err } - settings.IPStatusFilepath, err = paramsReader.GetIPStatusFilepath() - if err != nil { - return settings, err - } return settings, nil } @@ -43,7 +37,6 @@ func (s *System) String() string { fmt.Sprintf("User ID: %d", s.UID), fmt.Sprintf("Group ID: %d", s.GID), fmt.Sprintf("Timezone: %s", s.Timezone), - fmt.Sprintf("IP Status filepath: %s", s.IPStatusFilepath), } return strings.Join(settingsList, "\n|--") }