Fix several async issues

- race conditions between ctx.Done and waitError channel
- Sleep for retry cancels on cancelation of context
- Stops the any loop at the start if the context was canceled
- Mentions when loops exit
- Wait for errors on triggered loop restarts
This commit is contained in:
Quentin McGaw
2020-07-11 20:59:30 +00:00
parent 1ac06ee4a8
commit ccf11990f1
5 changed files with 67 additions and 48 deletions

View File

@@ -38,10 +38,12 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
}
}
func (l *looper) attemptingRestart(err error) {
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Warn(err)
l.logger.Info("attempting restart in 10 seconds")
time.Sleep(10 * time.Second)
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
<-ctx.Done()
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
@@ -53,8 +55,13 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done():
return
}
_, unboundCancel := context.WithCancel(ctx)
for {
defer l.logger.Warn("loop exited")
var unboundCtx context.Context
var unboundCancel context.CancelFunc = func() {}
var waitError chan error
triggeredRestart := false
for ctx.Err() == nil {
if !l.settings.Enabled {
// wait for another restart signal to recheck if it is enabled
select {
@@ -64,33 +71,34 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
return
}
}
if ctx.Err() == context.Canceled {
unboundCancel()
return
}
// Setup
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
l.attemptingRestart(err)
l.logAndWait(ctx, err)
continue
}
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
l.attemptingRestart(err)
l.logAndWait(ctx, err)
continue
}
if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil {
l.attemptingRestart(err)
l.logAndWait(ctx, err)
continue
}
// Start command
unboundCancel()
unboundCtx, unboundCancel := context.WithCancel(ctx)
if triggeredRestart {
triggeredRestart = false
unboundCancel()
<-waitError
close(waitError)
}
unboundCtx, unboundCancel = context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel)
if err != nil {
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
l.logAndWait(ctx, err)
continue
}
// Started successfully
@@ -98,16 +106,15 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
l.logger.Error(err)
}
if err := l.conf.WaitForUnbound(); err != nil {
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
l.logAndWait(ctx, err)
continue
}
waitError := make(chan error)
waitError = make(chan error)
go func() {
err := waitFn() // blocking
waitError <- err
@@ -122,16 +129,17 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
close(waitError)
return
case <-restart: // triggered restart
unboundCancel()
close(waitError)
l.logger.Info("restarting")
// unboundCancel occurs next loop run when the setup is complete
triggeredRestart = true
case err := <-waitError: // unexpected error
unboundCancel()
close(waitError)
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
l.logAndWait(ctx, err)
}
}
unboundCancel()
}
func (l *looper) fallbackToUnencryptedDNS() {

View File

@@ -36,10 +36,12 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
}
}
func (l *looper) logAndWait(err error) {
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err)
l.logger.Info("retrying in 5 seconds")
time.Sleep(5 * time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() // just for the linter
<-ctx.Done()
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
@@ -48,10 +50,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
case <-ctx.Done():
return
}
for {
defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
ip, err := l.getter.Get()
if err != nil {
l.logAndWait(err)
l.logAndWait(ctx, err)
continue
}
l.logger.Info("Public IP address is %s", ip)
@@ -61,7 +65,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
files.Ownership(l.uid, l.gid),
files.Permissions(0600))
if err != nil {
l.logAndWait(err)
l.logAndWait(ctx, err)
continue
}
select {

View File

@@ -37,6 +37,7 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
<-ctx.Done()
s.logger.Warn("context canceled: exiting loop")
defer s.logger.Warn("loop exited")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {

View File

@@ -27,10 +27,12 @@ type looper struct {
gid int
}
func (l *looper) logAndWait(err error) {
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err)
l.logger.Info("retrying in 1 minute")
time.Sleep(time.Minute)
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() // just for the linter
<-ctx.Done()
}
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS,
@@ -55,7 +57,9 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done():
return
}
for {
defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
nameserver := l.dnsSettings.PlaintextAddress.String()
if l.dnsSettings.Enabled {
nameserver = "127.0.0.1"
@@ -68,7 +72,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
l.uid,
l.gid)
if err != nil {
l.logAndWait(err)
l.logAndWait(ctx, err)
continue
}
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
@@ -76,11 +80,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
if err != nil {
l.logger.Error(err)
}
shadowsocksCtx, shadowsocksCancel := context.WithCancel(ctx)
stdout, stderr, waitFn, err := l.conf.Start(ctx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background())
stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
if err != nil {
shadowsocksCancel()
l.logAndWait(err)
l.logAndWait(ctx, err)
continue
}
go l.streamMerger.Merge(shadowsocksCtx, stdout,
@@ -102,13 +106,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-restart: // triggered restart
l.logger.Info("restarting")
shadowsocksCancel()
<-waitError
close(waitError)
case err := <-waitError: // unexpected error
l.logger.Warn(err)
l.logger.Info("restarting")
shadowsocksCancel()
close(waitError)
time.Sleep(time.Second)
l.logAndWait(ctx, err)
}
}
}

View File

@@ -26,10 +26,12 @@ type looper struct {
gid int
}
func (l *looper) logAndWait(err error) {
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err)
l.logger.Info("retrying in 1 minute")
time.Sleep(time.Minute)
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() // just for the linter
<-ctx.Done()
}
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy,
@@ -53,7 +55,9 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done():
return
}
for {
defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
err := l.conf.MakeConf(
l.settings.LogLevel,
l.settings.Port,
@@ -62,7 +66,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
l.uid,
l.gid)
if err != nil {
l.logAndWait(err)
l.logAndWait(ctx, err)
continue
}
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
@@ -70,11 +74,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
if err != nil {
l.logger.Error(err)
}
tinyproxyCtx, tinyproxyCancel := context.WithCancel(ctx)
tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(tinyproxyCtx)
if err != nil {
tinyproxyCancel()
l.logAndWait(err)
l.logAndWait(ctx, err)
continue
}
go l.streamMerger.Merge(tinyproxyCtx, stream,
@@ -94,13 +98,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-restart: // triggered restart
l.logger.Info("restarting")
tinyproxyCancel()
<-waitError
close(waitError)
case err := <-waitError: // unexpected error
l.logger.Warn(err)
l.logger.Info("restarting")
tinyproxyCancel()
close(waitError)
time.Sleep(time.Second)
l.logAndWait(ctx, err)
}
}
}