From 99e386abc8a6c0b96ab9bfcddc876e748e332efd Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 8 Jul 2020 23:36:02 +0000 Subject: [PATCH] Using a waitgroup to wait for all programs to exit --- cmd/gluetun/main.go | 23 ++++++++--------------- internal/dns/loop.go | 11 +++++------ internal/openvpn/loop.go | 9 +++++---- internal/server/server.go | 8 +++++--- internal/shadowsocks/loop.go | 9 +++++---- internal/tinyproxy/loop.go | 9 +++++---- 6 files changed, 33 insertions(+), 36 deletions(-) diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index bc9752ec..17af242a 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -7,6 +7,7 @@ import ( "os/signal" "regexp" "strings" + "sync" "syscall" "time" @@ -164,29 +165,25 @@ func _main(background context.Context, args []string) int { restartPublicIP := make(chan struct{}) restartTinyproxy := make(chan struct{}) restartShadowsocks := make(chan struct{}) - openvpnDone := make(chan struct{}) - unboundDone := make(chan struct{}) - serverDone := make(chan struct{}) - tinyproxyDone := make(chan struct{}) - shadowsocksDone := make(chan struct{}) + wg := &sync.WaitGroup{} openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError, uid, gid) // wait for restartOpenvpn - go openvpnLooper.Run(ctx, restartOpenvpn, openvpnDone) + go openvpnLooper.Run(ctx, restartOpenvpn, wg) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) // wait for restartUnbound - go unboundLooper.Run(ctx, restartUnbound, unboundDone) + go unboundLooper.Run(ctx, restartUnbound, wg) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, uid, gid) go publicIPLooper.Run(ctx, restartPublicIP) go publicIPLooper.RunRestartTicker(ctx, restartPublicIP) tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid) - go tinyproxyLooper.Run(ctx, restartTinyproxy, tinyproxyDone) + go tinyproxyLooper.Run(ctx, restartTinyproxy, wg) shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid) - go shadowsocksLooper.Run(ctx, restartShadowsocks, shadowsocksDone) + go shadowsocksLooper.Run(ctx, restartShadowsocks, wg) if allSettings.TinyProxy.Enabled { <-restartTinyproxy @@ -218,7 +215,7 @@ func _main(background context.Context, args []string) int { }() httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound) - go httpServer.Run(ctx, serverDone) + go httpServer.Run(ctx, wg) // Start openvpn for the first time restartOpenvpn <- struct{}{} @@ -256,11 +253,7 @@ func _main(background context.Context, args []string) int { logger.Error(err) exitStatus = 1 } - <-serverDone - <-unboundDone - <-openvpnDone - <-tinyproxyDone - <-shadowsocksDone + wg.Wait() return exitStatus } diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 45d16405..d8301cf1 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -3,6 +3,7 @@ package dns import ( "context" "net" + "sync" "time" "github.com/qdm12/golibs/command" @@ -12,7 +13,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) + Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, restart chan<- struct{}) } @@ -43,12 +44,13 @@ func (l *looper) attemptingRestart(err error) { time.Sleep(10 * time.Second) } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) { +func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() l.fallbackToUnencryptedDNS() select { case <-restart: case <-ctx.Done(): - close(done) return } _, unboundCancel := context.WithCancel(ctx) @@ -59,13 +61,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- s case <-restart: case <-ctx.Done(): unboundCancel() - close(done) return } } if ctx.Err() == context.Canceled { unboundCancel() - close(done) return } @@ -121,7 +121,6 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- s l.logger.Warn("context canceled: exiting loop") unboundCancel() close(waitError) - close(done) return case <-restart: // triggered restart unboundCancel() diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index f7f24c7d..f83e48be 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -2,6 +2,7 @@ package openvpn import ( "context" + "sync" "time" "github.com/qdm12/golibs/command" @@ -11,7 +12,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) + Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) } type looper struct { @@ -37,11 +38,12 @@ func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logg } } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) { +func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() select { case <-restart: case <-ctx.Done(): - close(done) return } for { @@ -69,7 +71,6 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- s l.logger.Warn("context canceled: exiting loop") openvpnCancel() close(waitError) - close(done) return case <-restart: // triggered restart l.logger.Info("restarting") diff --git a/internal/server/server.go b/internal/server/server.go index f05432b3..77e34acd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "net/http" + "sync" "time" "github.com/qdm12/golibs/logging" ) type Server interface { - Run(ctx context.Context, serverDone chan struct{}) + Run(ctx context.Context, wg *sync.WaitGroup) } type server struct { @@ -29,10 +30,11 @@ func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound c } } -func (s *server) Run(ctx context.Context, serverDone chan struct{}) { +func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { + wg.Add(1) server := http.Server{Addr: s.address, Handler: s.makeHandler()} go func() { - defer close(serverDone) + defer wg.Done() <-ctx.Done() s.logger.Warn("context canceled: exiting loop") shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) diff --git a/internal/shadowsocks/loop.go b/internal/shadowsocks/loop.go index 031878ba..d4981755 100644 --- a/internal/shadowsocks/loop.go +++ b/internal/shadowsocks/loop.go @@ -2,6 +2,7 @@ package shadowsocks import ( "context" + "sync" "time" "github.com/qdm12/golibs/command" @@ -12,7 +13,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) + Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) } type looper struct { @@ -46,11 +47,12 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s } } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) { +func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() select { case <-restart: case <-ctx.Done(): - close(done) return } for { @@ -97,7 +99,6 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- s l.logger.Warn("context canceled: exiting loop") shadowsocksCancel() close(waitError) - close(done) return case <-restart: // triggered restart l.logger.Info("restarting") diff --git a/internal/tinyproxy/loop.go b/internal/tinyproxy/loop.go index 67e75b12..46b875c8 100644 --- a/internal/tinyproxy/loop.go +++ b/internal/tinyproxy/loop.go @@ -2,6 +2,7 @@ package tinyproxy import ( "context" + "sync" "time" "github.com/qdm12/golibs/command" @@ -12,7 +13,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) + Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) } type looper struct { @@ -44,11 +45,12 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s } } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) { +func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() select { case <-restart: case <-ctx.Done(): - close(done) return } for { @@ -89,7 +91,6 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- s l.logger.Warn("context canceled: exiting loop") tinyproxyCancel() close(waitError) - close(done) return case <-restart: // triggered restart l.logger.Info("restarting")