diff --git a/internal/dns/loop.go b/internal/dns/loop.go index eae0ac7c..99aeb346 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -38,10 +38,12 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, } } -func (l *looper) attemptingRestart(err error) { +func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Warn(err) l.logger.Info("attempting restart in 10 seconds") - time.Sleep(10 * time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + <-ctx.Done() } func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { @@ -53,8 +55,13 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-ctx.Done(): return } - _, unboundCancel := context.WithCancel(ctx) - for { + defer l.logger.Warn("loop exited") + + var unboundCtx context.Context + var unboundCancel context.CancelFunc = func() {} + var waitError chan error + triggeredRestart := false + for ctx.Err() == nil { if !l.settings.Enabled { // wait for another restart signal to recheck if it is enabled select { @@ -64,33 +71,34 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait return } } - if ctx.Err() == context.Canceled { - unboundCancel() - return - } // Setup if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil { - l.attemptingRestart(err) + l.logAndWait(ctx, err) continue } if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil { - l.attemptingRestart(err) + l.logAndWait(ctx, err) continue } if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil { - l.attemptingRestart(err) + l.logAndWait(ctx, err) continue } - // Start command - unboundCancel() - unboundCtx, unboundCancel := context.WithCancel(ctx) + if triggeredRestart { + triggeredRestart = false + unboundCancel() + <-waitError + close(waitError) + } + unboundCtx, unboundCancel = context.WithCancel(context.Background()) stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel) if err != nil { unboundCancel() l.fallbackToUnencryptedDNS() - l.attemptingRestart(err) + l.logAndWait(ctx, err) + continue } // Started successfully @@ -98,16 +106,15 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound - unboundCancel() - l.fallbackToUnencryptedDNS() - l.attemptingRestart(err) + l.logger.Error(err) } if err := l.conf.WaitForUnbound(); err != nil { unboundCancel() l.fallbackToUnencryptedDNS() - l.attemptingRestart(err) + l.logAndWait(ctx, err) + continue } - waitError := make(chan error) + waitError = make(chan error) go func() { err := waitFn() // blocking waitError <- err @@ -122,16 +129,17 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait close(waitError) return case <-restart: // triggered restart - unboundCancel() - close(waitError) l.logger.Info("restarting") + // unboundCancel occurs next loop run when the setup is complete + triggeredRestart = true case err := <-waitError: // unexpected error - unboundCancel() close(waitError) + unboundCancel() l.fallbackToUnencryptedDNS() - l.attemptingRestart(err) + l.logAndWait(ctx, err) } } + unboundCancel() } func (l *looper) fallbackToUnencryptedDNS() { diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index cecb1031..be374288 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -36,10 +36,12 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F } } -func (l *looper) logAndWait(err error) { +func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Error(err) l.logger.Info("retrying in 5 seconds") - time.Sleep(5 * time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() // just for the linter + <-ctx.Done() } func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { @@ -48,10 +50,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { case <-ctx.Done(): return } - for { + defer l.logger.Warn("loop exited") + + for ctx.Err() == nil { ip, err := l.getter.Get() if err != nil { - l.logAndWait(err) + l.logAndWait(ctx, err) continue } l.logger.Info("Public IP address is %s", ip) @@ -61,7 +65,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { files.Ownership(l.uid, l.gid), files.Permissions(0600)) if err != nil { - l.logAndWait(err) + l.logAndWait(ctx, err) continue } select { diff --git a/internal/server/server.go b/internal/server/server.go index 77e34acd..397c9cc0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -37,6 +37,7 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() <-ctx.Done() s.logger.Warn("context canceled: exiting loop") + defer s.logger.Warn("loop exited") shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := server.Shutdown(shutdownCtx); err != nil { diff --git a/internal/shadowsocks/loop.go b/internal/shadowsocks/loop.go index 05e8fcde..6c398162 100644 --- a/internal/shadowsocks/loop.go +++ b/internal/shadowsocks/loop.go @@ -27,10 +27,12 @@ type looper struct { gid int } -func (l *looper) logAndWait(err error) { +func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Error(err) l.logger.Info("retrying in 1 minute") - time.Sleep(time.Minute) + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() // just for the linter + <-ctx.Done() } func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS, @@ -55,7 +57,9 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-ctx.Done(): return } - for { + defer l.logger.Warn("loop exited") + + for ctx.Err() == nil { nameserver := l.dnsSettings.PlaintextAddress.String() if l.dnsSettings.Enabled { nameserver = "127.0.0.1" @@ -68,7 +72,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait l.uid, l.gid) if err != nil { - l.logAndWait(err) + l.logAndWait(ctx, err) continue } err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port) @@ -76,11 +80,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait if err != nil { l.logger.Error(err) } - shadowsocksCtx, shadowsocksCancel := context.WithCancel(ctx) - stdout, stderr, waitFn, err := l.conf.Start(ctx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log) + shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background()) + stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log) if err != nil { shadowsocksCancel() - l.logAndWait(err) + l.logAndWait(ctx, err) continue } go l.streamMerger.Merge(shadowsocksCtx, stdout, @@ -102,13 +106,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-restart: // triggered restart l.logger.Info("restarting") shadowsocksCancel() + <-waitError close(waitError) case err := <-waitError: // unexpected error - l.logger.Warn(err) - l.logger.Info("restarting") shadowsocksCancel() close(waitError) - time.Sleep(time.Second) + l.logAndWait(ctx, err) } } } diff --git a/internal/tinyproxy/loop.go b/internal/tinyproxy/loop.go index b44e37ec..fbd9580f 100644 --- a/internal/tinyproxy/loop.go +++ b/internal/tinyproxy/loop.go @@ -26,10 +26,12 @@ type looper struct { gid int } -func (l *looper) logAndWait(err error) { +func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Error(err) l.logger.Info("retrying in 1 minute") - time.Sleep(time.Minute) + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() // just for the linter + <-ctx.Done() } func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy, @@ -53,7 +55,9 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-ctx.Done(): return } - for { + defer l.logger.Warn("loop exited") + + for ctx.Err() == nil { err := l.conf.MakeConf( l.settings.LogLevel, l.settings.Port, @@ -62,7 +66,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait l.uid, l.gid) if err != nil { - l.logAndWait(err) + l.logAndWait(ctx, err) continue } err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port) @@ -70,11 +74,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait if err != nil { l.logger.Error(err) } - tinyproxyCtx, tinyproxyCancel := context.WithCancel(ctx) + tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background()) stream, waitFn, err := l.conf.Start(tinyproxyCtx) if err != nil { tinyproxyCancel() - l.logAndWait(err) + l.logAndWait(ctx, err) continue } go l.streamMerger.Merge(tinyproxyCtx, stream, @@ -94,13 +98,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait case <-restart: // triggered restart l.logger.Info("restarting") tinyproxyCancel() + <-waitError close(waitError) case err := <-waitError: // unexpected error - l.logger.Warn(err) - l.logger.Info("restarting") tinyproxyCancel() close(waitError) - time.Sleep(time.Second) + l.logAndWait(ctx, err) } } }