From 4218dba177674f4a9e8ac98f98fc5ee0da4f4ccc Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 18 May 2024 18:06:01 +0000 Subject: [PATCH] fix(publicip): abort ip data fetch if vpn context is canceled - Prevents requesting the public IP address N times after N VPN failures - Fetching runs with a context local to the 'single run' - Single run writes single run result to a channel back to the caller, RunOnce is now blocking --- internal/publicip/loop.go | 87 +++++++++++++++++++------------------- internal/vpn/interfaces.go | 2 +- internal/vpn/tunnelup.go | 7 ++- 3 files changed, 50 insertions(+), 46 deletions(-) diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index 73919f72..9741502d 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -2,7 +2,6 @@ package publicip import ( "context" - "errors" "fmt" "net/netip" "sync" @@ -10,7 +9,6 @@ import ( "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/gluetun/internal/publicip/api" ) type Loop struct { @@ -30,7 +28,8 @@ type Loop struct { // when performing an update runCtx context.Context //nolint:containedctx runCancel context.CancelFunc - runTrigger chan<- struct{} + runTrigger chan<- context.Context + runResult <-chan error updateTrigger chan<- settings.PublicIP updatedResult <-chan error runDone <-chan struct{} @@ -58,21 +57,23 @@ func (l *Loop) Start(_ context.Context) (_ <-chan error, err error) { l.runCtx, l.runCancel = context.WithCancel(context.Background()) runDone := make(chan struct{}) l.runDone = runDone - runTrigger := make(chan struct{}) + runTrigger := make(chan context.Context) l.runTrigger = runTrigger + runResult := make(chan error) + l.runResult = runResult updateTrigger := make(chan settings.PublicIP) l.updateTrigger = updateTrigger updatedResult := make(chan error) l.updatedResult = updatedResult - go l.run(l.runCtx, runDone, runTrigger, updateTrigger, updatedResult) + go l.run(l.runCtx, runDone, runTrigger, runResult, updateTrigger, updatedResult) return nil, nil //nolint:nilnil } func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, - runTrigger <-chan struct{}, updateTrigger <-chan settings.PublicIP, - updatedResult chan<- error) { + runTrigger <-chan context.Context, runResult chan<- error, + updateTrigger <-chan settings.PublicIP, updatedResult chan<- error) { defer close(runDone) timer := time.NewTimer(time.Hour) @@ -82,10 +83,14 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, lastFetch := time.Unix(0, 0) for { + singleRunCtx := runCtx + var singleRunResult chan<- error select { case <-runCtx.Done(): return - case <-runTrigger: + case singleRunCtx = <-runTrigger: + // Note singleRunCtx is canceled if runCtx is canceled. + singleRunResult = runResult case <-timer.C: timerIsReadyToReset = true case partialUpdate := <-updateTrigger: @@ -95,15 +100,17 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, continue } - result, err := l.fetchIPData(runCtx) - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return - } - lastFetch = l.timeNow() timerIsReadyToReset = l.updateTimer(*l.settings.Period, lastFetch, timer, timerIsReadyToReset) - if errors.Is(err, api.ErrTooManyRequests) { + result, err := l.fetcher.FetchInfo(singleRunCtx, netip.Addr{}) + if err != nil { + err = fmt.Errorf("fetching information: %w", err) + if singleRunResult != nil { + singleRunResult <- err + } else { + l.logger.Error(err.Error()) + } continue } @@ -117,42 +124,36 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, filepath := *l.settings.IPFilepath err = persistPublicIP(filepath, result.IP.String(), l.puid, l.pgid) - if err != nil { // non critical error, which can be fixed with settings updates. + if err != nil { + err = fmt.Errorf("persisting public ip address: %w", err) + } + + if singleRunResult != nil { + singleRunResult <- err + } else if err != nil { l.logger.Error(err.Error()) } } } -func (l *Loop) fetchIPData(ctx context.Context) (result models.PublicIP, err error) { - // keep retrying since settings updates won't change the - // behavior of the following code. - const defaultBackoffTime = 5 * time.Second - backoffTime := defaultBackoffTime - for { - result, err = l.fetcher.FetchInfo(ctx, netip.Addr{}) - switch { - case err == nil: - return result, nil - case ctx.Err() != nil: - return result, err - case errors.Is(err, api.ErrTooManyRequests): - l.logger.Warn(err.Error() + "; not retrying.") - return result, err - } - - l.logger.Error(fmt.Sprintf("%s - retrying in %s", err, backoffTime)) - select { - case <-ctx.Done(): - return result, ctx.Err() - case <-time.After(backoffTime): - } - const backoffTimeMultipler = 2 - backoffTime *= backoffTimeMultipler +func (l *Loop) RunOnce(ctx context.Context) (err error) { + singleRunCtx, singleRunCancel := context.WithCancel(ctx) + select { + case l.runTrigger <- singleRunCtx: + case <-ctx.Done(): // in case writing to run trigger is blocking + singleRunCancel() + return ctx.Err() } -} -func (l *Loop) StartSingleRun() { - l.runTrigger <- struct{}{} + select { + case err = <-l.runResult: + singleRunCancel() + return err + case <-l.runCtx.Done(): + singleRunCancel() + <-l.runResult + return l.runCtx.Err() + } } func (l *Loop) UpdateWith(partialUpdate settings.PublicIP) (err error) { diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 6fe8b4f6..6a4a145b 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -89,6 +89,6 @@ type DNSLoop interface { } type PublicIPLoop interface { - StartSingleRun() + RunOnce(ctx context.Context) (err error) ClearData() (err error) } diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 58a88551..5126b9c3 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -29,7 +29,10 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { _, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running) } - l.publicip.StartSingleRun() + err := l.publicip.RunOnce(ctx) + if err != nil { + l.logger.Error("getting public IP address information: " + err.Error()) + } if l.versionInfo { l.versionInfo = false // only get the version information once @@ -41,7 +44,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { } } - err := l.startPortForwarding(data) + err = l.startPortForwarding(data) if err != nil { l.logger.Error(err.Error()) }