diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c052c14e..249d18a1 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -9,7 +9,6 @@ import ( nativeos "os" "os/signal" "strings" - "sync" "syscall" "time" @@ -29,6 +28,7 @@ import ( "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/server" "github.com/qdm12/gluetun/internal/shadowsocks" + "github.com/qdm12/gluetun/internal/shutdown" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/unix" "github.com/qdm12/gluetun/internal/updater" @@ -255,44 +255,56 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } } // TODO move inside firewall? - wg := &sync.WaitGroup{} + const ( + shutdownMaxTimeout = 3 * time.Second + shutdownRoutineTimeout = 400 * time.Millisecond + shutdownOpenvpnTimeout = time.Second + ) + healthy := make(chan bool) + controlWave := shutdown.NewWave("control") + tickerWave := shutdown.NewWave("tickers") + healthWave := shutdown.NewWave("health") + dnsWave := shutdown.NewWave("DNS") + vpnWave := shutdown.NewWave("VPN") + serverWave := shutdown.NewWave("servers") openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh, healthy) - wg.Add(1) + openvpnCtx, openvpnDone := vpnWave.Add("openvpn", shutdownOpenvpnTimeout) // wait for restartOpenvpn - go openvpnLooper.Run(ctx, wg) + go openvpnLooper.Run(openvpnCtx, openvpnDone) updaterLooper := updater.NewLooper(allSettings.Updater, allServers, storage, openvpnLooper.SetServers, httpClient, logger) - wg.Add(1) + updaterCtx, updaterDone := tickerWave.Add("updater", shutdownRoutineTimeout) // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker - go updaterLooper.Run(ctx, wg) + go updaterLooper.Run(updaterCtx, updaterDone) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient, logger, nonRootUsername, puid, pgid) - wg.Add(1) + dnsCtx, dnsDone := dnsWave.Add("unbound", shutdownRoutineTimeout) // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker - go unboundLooper.Run(ctx, wg) + go unboundLooper.Run(dnsCtx, dnsDone) publicIPLooper := publicip.NewLooper( httpClient, logger, allSettings.PublicIP, puid, pgid, os) - wg.Add(1) - go publicIPLooper.Run(ctx, wg) - wg.Add(1) - go publicIPLooper.RunRestartTicker(ctx, wg) + pubIPCtx, pubIPDone := serverWave.Add("public IP", shutdownRoutineTimeout) + go publicIPLooper.Run(pubIPCtx, pubIPDone) + + pubIPTickerCtx, pubIPTickerDone := tickerWave.Add("public IP", shutdownRoutineTimeout) + go publicIPLooper.RunRestartTicker(pubIPTickerCtx, pubIPTickerDone) httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy) - wg.Add(1) - go httpProxyLooper.Run(ctx, wg) + httpProxyCtx, httpProxyDone := serverWave.Add("http proxy", shutdownRoutineTimeout) + go httpProxyLooper.Run(httpProxyCtx, httpProxyDone) shadowsocksLooper := shadowsocks.NewLooper(allSettings.ShadowSocks, logger) - wg.Add(1) - go shadowsocksLooper.Run(ctx, wg) + shadowsocksCtx, shadowsocksDone := serverWave.Add("shadowsocks proxy", shutdownRoutineTimeout) + go shadowsocksLooper.Run(shadowsocksCtx, shadowsocksDone) - wg.Add(1) - go routeReadyEvents(ctx, wg, buildInfo, tunnelReadyCh, + eventsRoutingCtx, eventsRoutingDone := controlWave.Add("events routing", shutdownRoutineTimeout) + go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh, unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, ) @@ -300,13 +312,17 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, controlServerLogging := allSettings.ControlServer.Log httpServer := server.New(controlServerAddress, controlServerLogging, logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) - wg.Add(1) - go httpServer.Run(ctx, wg) + httpServerCtx, httpServerDone := controlWave.Add("http server", shutdownRoutineTimeout) + go httpServer.Run(httpServerCtx, httpServerDone) - healthcheckServer := healthcheck.NewServer( - constants.HealthcheckAddress, logger) - wg.Add(1) - go healthcheckServer.Run(ctx, healthy, wg) + healthcheckServer := healthcheck.NewServer(constants.HealthcheckAddress, logger) + healthServerCtx, healthServerDone := healthWave.Add("HTTP health server", shutdownRoutineTimeout) + go healthcheckServer.Run(healthServerCtx, healthy, healthServerDone) + + shutdownOrder := shutdown.NewOrder() + shutdownOrder.Append(controlWave, tickerWave, healthWave, + dnsWave, vpnWave, serverWave, + ) // Start openvpn for the first time in a blocking call // until openvpn is launched @@ -321,6 +337,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } } + if err := shutdownOrder.Shutdown(shutdownMaxTimeout, logger); err != nil { + return err + } + + // Only disable firewall if everything has shutdown gracefully if allSettings.Firewall.Enabled { const enable = false err := firewallConf.SetEnabled(context.Background(), enable) @@ -329,8 +350,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } } - wg.Wait() - return nil } @@ -349,22 +368,29 @@ func printVersions(ctx context.Context, logger logging.Logger, } } -func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.BuildInformation, +func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo models.BuildInformation, tunnelReadyCh <-chan struct{}, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, routing routing.Routing, logger logging.Logger, httpClient *http.Client, versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) { - defer wg.Done() - tickerWg := &sync.WaitGroup{} + defer close(done) + // for linters only var restartTickerContext context.Context var restartTickerCancel context.CancelFunc = func() {} + + unboundTickerDone := make(chan struct{}) + close(unboundTickerDone) + updaterTickerDone := make(chan struct{}) + close(updaterTickerDone) + first := true for { select { case <-ctx.Done(): restartTickerCancel() // for linters only - tickerWg.Wait() + <-unboundTickerDone + <-updaterTickerDone return case <-tunnelReadyCh: // blocks until openvpn is connected vpnDestination, err := routing.VPNDestinationIP() @@ -379,7 +405,8 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models. } restartTickerCancel() // stop previous restart tickers - tickerWg.Wait() + <-unboundTickerDone + <-updaterTickerDone restartTickerContext, restartTickerCancel = context.WithCancel(ctx) // Runs the Public IP getter job once @@ -398,10 +425,10 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models. } } - //nolint:gomnd - tickerWg.Add(2) - go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) - go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) + unboundTickerDone = make(chan struct{}) + updaterTickerDone = make(chan struct{}) + go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone) + go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone) if portForwardingEnabled { // vpnGateway required only for PIA vpnGateway, err := routing.VPNLocalGatewayIP() diff --git a/internal/dns/logs.go b/internal/dns/logs.go index 30b6c157..d136ef06 100644 --- a/internal/dns/logs.go +++ b/internal/dns/logs.go @@ -3,14 +3,13 @@ package dns import ( "regexp" "strings" - "sync" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/golibs/logging" ) -func (l *looper) collectLines(wg *sync.WaitGroup, stdout, stderr <-chan string) { - defer wg.Done() +func (l *looper) collectLines(stdout, stderr <-chan string, done chan<- struct{}) { + defer close(done) var line string var ok bool for { diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 4f13c262..63183056 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -17,8 +17,8 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) - RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) + RunRestartTicker(ctx context.Context, done chan<- struct{}) GetStatus() (status models.LoopStatus) SetStatus(status models.LoopStatus) (outcome string, err error) GetSettings() (settings configuration.DNS) @@ -86,8 +86,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) const fallback = false l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution? @@ -99,8 +99,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { return } - defer l.logger.Warn("loop exited") - crashed := false l.backoffTime = defaultBackoffTime @@ -116,11 +114,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { if !crashed { l.running <- constants.Stopped } - l.logger.Warn("context canceled: exiting loop") return } var err error - unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, wg, crashed) + unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed) if err != nil { if !errors.Is(err, errUpdateFiles) { const fallback = true @@ -143,7 +140,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for stayHere { select { case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") unboundCancel() <-waitError close(waitError) @@ -178,9 +174,8 @@ var errUpdateFiles = errors.New("cannot update files") // Returning cancel == nil signals we want to re-run setupUnbound // Returning err == errUpdateFiles signals we should not fall back // on the plaintext DNS as DOT is still up and running. -func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup, - previousCrashed bool) (cancel context.CancelFunc, waitError chan error, - closeStreams func(), err error) { +func (l *looper) setupUnbound(ctx context.Context, previousCrashed bool) ( + cancel context.CancelFunc, waitError chan error, closeStreams func(), err error) { err = l.updateFiles(ctx) if err != nil { l.state.setStatusWithLock(constants.Crashed) @@ -199,8 +194,8 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup, return nil, nil, nil, err } - wg.Add(1) - go l.collectLines(wg, stdoutLines, stderrLines) + collectLinesDone := make(chan struct{}) + go l.collectLines(stdoutLines, stderrLines, collectLinesDone) l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound @@ -216,6 +211,7 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup, close(waitError) close(stdoutLines) close(stderrLines) + <-collectLinesDone return nil, nil, nil, err } @@ -230,6 +226,7 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup, closeStreams = func() { close(stdoutLines) close(stderrLines) + <-collectLinesDone } return cancel, waitError, closeStreams, nil @@ -276,8 +273,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) { l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers) } -func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) { + defer close(done) // Timer that acts as a ticker timer := time.NewTimer(time.Hour) timer.Stop() diff --git a/internal/healthcheck/health.go b/internal/healthcheck/health.go index c51d3020..48f3741b 100644 --- a/internal/healthcheck/health.go +++ b/internal/healthcheck/health.go @@ -6,12 +6,11 @@ import ( "errors" "fmt" "net" - "sync" "time" ) -func (s *server) runHealthcheckLoop(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) { - defer wg.Done() +func (s *server) runHealthcheckLoop(ctx context.Context, healthy chan<- bool, done chan<- struct{}) { + defer close(done) for { previousErr := s.handler.getErr() diff --git a/internal/healthcheck/server.go b/internal/healthcheck/server.go index 69b7c988..ff512bdd 100644 --- a/internal/healthcheck/server.go +++ b/internal/healthcheck/server.go @@ -5,14 +5,13 @@ import ( "errors" "net" "net/http" - "sync" "time" "github.com/qdm12/golibs/logging" ) type Server interface { - Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) + Run(ctx context.Context, healthy chan<- bool, done chan<- struct{}) } type server struct { @@ -32,23 +31,20 @@ func NewServer(address string, logger logging.Logger) Server { } } -func (s *server) Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) { - defer wg.Done() +func (s *server) Run(ctx context.Context, healthy chan<- bool, done chan<- struct{}) { + defer close(done) - internalWg := &sync.WaitGroup{} - internalWg.Add(1) - go s.runHealthcheckLoop(ctx, healthy, internalWg) + loopDone := make(chan struct{}) + go s.runHealthcheckLoop(ctx, healthy, loopDone) server := http.Server{ Addr: s.address, Handler: s.handler, } - internalWg.Add(1) + serverDone := make(chan struct{}) go func() { - defer internalWg.Done() + defer close(serverDone) <-ctx.Done() - s.logger.Warn("context canceled: shutting down server") - defer s.logger.Warn("server shut down") const shutdownGraceDuration = 2 * time.Second shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) defer cancel() @@ -63,5 +59,6 @@ func (s *server) Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGrou s.logger.Error(err) } - internalWg.Wait() + <-loopDone + <-serverDone } diff --git a/internal/httpproxy/loop.go b/internal/httpproxy/loop.go index ac6c4873..5e7d7cf0 100644 --- a/internal/httpproxy/loop.go +++ b/internal/httpproxy/loop.go @@ -14,7 +14,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) SetStatus(status models.LoopStatus) (outcome string, err error) GetStatus() (status models.LoopStatus) GetSettings() (settings configuration.HTTPProxy) @@ -50,8 +50,8 @@ func NewLooper(logger logging.Logger, settings configuration.HTTPProxy) Looper { } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) crashed := false @@ -67,8 +67,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { return } - defer l.logger.Warn("loop exited") - for ctx.Err() == nil { runCtx, runCancel := context.WithCancel(ctx) @@ -76,10 +74,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { address := fmt.Sprintf(":%d", settings.Port) server := New(runCtx, address, l.logger, settings.Stealth, settings.Log, settings.User, settings.Password) - runWg := &sync.WaitGroup{} - runWg.Add(1) errorCh := make(chan error) - go server.Run(runCtx, runWg, errorCh) + go server.Run(runCtx, errorCh) // TODO stable timer, check Shadowsocks if !crashed { @@ -94,22 +90,20 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for stayHere { select { case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") runCancel() - runWg.Wait() + <-errorCh return case <-l.start: l.logger.Info("starting") runCancel() - runWg.Wait() + <-errorCh stayHere = false case <-l.stop: l.logger.Info("stopping") runCancel() - runWg.Wait() + <-errorCh l.stopped <- struct{}{} case err := <-errorCh: - runWg.Wait() l.state.setStatusWithLock(constants.Crashed) l.logAndWait(ctx, err) crashed = true diff --git a/internal/httpproxy/server.go b/internal/httpproxy/server.go index 0942d007..af5ebae6 100644 --- a/internal/httpproxy/server.go +++ b/internal/httpproxy/server.go @@ -10,7 +10,7 @@ import ( ) type Server interface { - Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- error) + Run(ctx context.Context, errorCh chan<- error) } type server struct { @@ -31,13 +31,10 @@ func New(ctx context.Context, address string, logger logging.Logger, } } -func (s *server) Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- error) { - defer wg.Done() +func (s *server) Run(ctx context.Context, errorCh chan<- error) { server := http.Server{Addr: s.address, Handler: s.handler} go func() { <-ctx.Done() - s.logger.Warn("shutting down server") - defer s.logger.Warn("server shut down") const shutdownGraceDuration = 2 * time.Second shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) defer cancel() @@ -47,8 +44,10 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- err }() s.logger.Info("listening on %s", s.address) err := server.ListenAndServe() + s.internalWG.Wait() if err != nil && ctx.Err() == nil { errorCh <- err + } else { + errorCh <- nil } - s.internalWG.Wait() } diff --git a/internal/openvpn/logs.go b/internal/openvpn/logs.go index db6f6c45..76fb0f27 100644 --- a/internal/openvpn/logs.go +++ b/internal/openvpn/logs.go @@ -2,15 +2,14 @@ package openvpn import ( "strings" - "sync" "github.com/fatih/color" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/golibs/logging" ) -func (l *looper) collectLines(wg *sync.WaitGroup, stdout, stderr <-chan string) { - defer wg.Done() +func (l *looper) collectLines(stdout, stderr <-chan string, done chan<- struct{}) { + defer close(done) var line string var ok, errLine bool diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 38e54b50..998adbf6 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -19,7 +19,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) GetStatus() (status models.LoopStatus) SetStatus(status models.LoopStatus) (outcome string, err error) GetSettings() (settings configuration.OpenVPN) @@ -104,14 +104,13 @@ func (l *looper) signalCrashedStatus() { } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocognit - defer wg.Done() +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit + defer close(done) select { case <-l.start: case <-ctx.Done(): return } - defer l.logger.Warn("loop exited") for ctx.Err() == nil { settings, allServers := l.state.getSettingsAndServers() @@ -166,20 +165,19 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocogni continue } - wg.Add(1) - go l.collectLines(wg, stdoutLines, stderrLines) + lineCollectionDone := make(chan struct{}) + go l.collectLines(stdoutLines, stderrLines, lineCollectionDone) // Needs the stream line from main.go to know when the tunnel is up + portForwardDone := make(chan struct{}) go func(ctx context.Context) { - for { - select { - // TODO have a way to disable pf with a context - case <-ctx.Done(): - return - case gateway := <-l.portForwardSignals: - wg.Add(1) - go l.portForward(ctx, wg, providerConf, l.client, gateway) - } + defer close(portForwardDone) + select { + // TODO have a way to disable pf with a context + case <-ctx.Done(): + return + case gateway := <-l.portForwardSignals: + l.portForward(ctx, providerConf, l.client, gateway) } }(openvpnCtx) @@ -195,12 +193,13 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocogni for stayHere { select { case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") openvpnCancel() <-waitError close(waitError) close(stdoutLines) close(stderrLines) + <-lineCollectionDone + <-portForwardDone return case <-l.stop: l.logger.Info("stopping") @@ -288,9 +287,8 @@ func (l *looper) waitForHealth(ctx context.Context) (healthy bool) { // portForward is a blocking operation which may or may not be infinite. // You should therefore always call it in a goroutine. -func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup, +func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client *http.Client, gateway net.IP) { - defer wg.Done() l.state.portForwardedMu.RLock() settings := l.state.settings l.state.portForwardedMu.RUnlock() diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 3eef1e86..40943348 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -29,8 +29,6 @@ var ( func (p *PIA) PortForward(ctx context.Context, client *http.Client, openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { - defer logger.Warn("loop exited") - commonName := p.activeServer.ServerName if !p.activeServer.PortForward { logger.Error("The server " + commonName + diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index 5f7e4efc..074a4b3c 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -15,8 +15,8 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) - RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) + RunRestartTicker(ctx context.Context, done chan<- struct{}) GetStatus() (status models.LoopStatus) SetStatus(status models.LoopStatus) (outcome string, err error) GetSettings() (settings configuration.PublicIP) @@ -91,8 +91,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) crashed := false @@ -101,7 +101,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { case <-ctx.Done(): return } - defer l.logger.Warn("loop exited") for ctx.Err() == nil { getCtx, getCancel := context.WithCancel(ctx) @@ -132,11 +131,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for stayHere { select { case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") getCancel() close(errorCh) filepath := l.GetSettings().IPFilepath - l.logger.Info("Removing ip file %s", filepath) + l.logger.Info("Removing ip file " + filepath) if err := l.os.Remove(filepath); err != nil { l.logger.Error(err) } @@ -181,8 +179,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { } } -func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) { + defer close(done) timer := time.NewTimer(time.Hour) timer.Stop() // 1 hour, cannot be a race condition timerIsStopped := true diff --git a/internal/server/server.go b/internal/server/server.go index 207ebf93..7779a298 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,7 +4,6 @@ package server import ( "context" "net/http" - "sync" "time" "github.com/qdm12/gluetun/internal/dns" @@ -16,7 +15,7 @@ import ( ) type Server interface { - Run(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) } type server struct { @@ -39,12 +38,11 @@ func New(address string, logEnabled bool, logger logging.Logger, } } -func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (s *server) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) server := http.Server{Addr: s.address, Handler: s.handler} go func() { <-ctx.Done() - s.logger.Warn("context canceled: shutting down") const shutdownGraceDuration = 2 * time.Second shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) defer cancel() @@ -57,5 +55,4 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { if err != nil && ctx.Err() != context.Canceled { s.logger.Error(err) } - s.logger.Warn("shut down") } diff --git a/internal/shadowsocks/loop.go b/internal/shadowsocks/loop.go index a74a6c04..54d09fdb 100644 --- a/internal/shadowsocks/loop.go +++ b/internal/shadowsocks/loop.go @@ -15,7 +15,7 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) SetStatus(status models.LoopStatus) (outcome string, err error) GetStatus() (status models.LoopStatus) GetSettings() (settings configuration.ShadowSocks) @@ -67,8 +67,8 @@ func NewLooper(settings configuration.ShadowSocks, logger logging.Logger) Looper } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) crashed := false @@ -84,8 +84,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { return } - defer l.logger.Warn("loop exited") - for ctx.Err() == nil { settings := l.GetSettings() server, err := shadowsockslib.NewServer(settings.Method, settings.Password, adaptLogger(l.logger, settings.Log)) @@ -114,7 +112,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for stayHere { select { case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") shadowsocksCancel() <-waitError close(waitError) diff --git a/internal/shutdown/order.go b/internal/shutdown/order.go new file mode 100644 index 00000000..4347b8ea --- /dev/null +++ b/internal/shutdown/order.go @@ -0,0 +1,50 @@ +package shutdown + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/qdm12/golibs/logging" +) + +type Order interface { + Append(waves ...Wave) + Shutdown(timeout time.Duration, logger logging.Logger) (err error) +} + +type order struct { + waves []Wave + total int // for logging only +} + +func NewOrder() Order { + return &order{} +} + +var ErrIncomplete = errors.New("one or more routines did not terminate gracefully") + +func (o *order) Append(waves ...Wave) { + o.waves = append(o.waves, waves...) +} + +func (o *order) Shutdown(timeout time.Duration, logger logging.Logger) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + total := 0 + incomplete := 0 + + for _, wave := range o.waves { + total += wave.size() + incomplete += wave.shutdown(ctx, logger) + } + + if incomplete == 0 { + return nil + } + + return fmt.Errorf("%w: %d not terminated on %d routines", + ErrIncomplete, incomplete, total) +} diff --git a/internal/shutdown/routine.go b/internal/shutdown/routine.go new file mode 100644 index 00000000..492d37f3 --- /dev/null +++ b/internal/shutdown/routine.go @@ -0,0 +1,39 @@ +package shutdown + +import ( + "context" + "fmt" + "time" +) + +type routine struct { + name string + cancel context.CancelFunc + done <-chan struct{} + timeout time.Duration +} + +func newRoutine(name string) (r routine, + ctx context.Context, done chan struct{}) { + ctx, cancel := context.WithCancel(context.Background()) + done = make(chan struct{}) + return routine{ + name: name, + cancel: cancel, + done: done, + }, ctx, done +} + +func (r *routine) shutdown(ctx context.Context) (err error) { + ctx, cancel := context.WithTimeout(ctx, r.timeout) + defer cancel() + + r.cancel() + + select { + case <-r.done: + return nil + case <-ctx.Done(): + return fmt.Errorf("for routine %q: %w", r.name, ctx.Err()) + } +} diff --git a/internal/shutdown/wave.go b/internal/shutdown/wave.go new file mode 100644 index 00000000..2be3a8b9 --- /dev/null +++ b/internal/shutdown/wave.go @@ -0,0 +1,66 @@ +package shutdown + +import ( + "context" + "time" + + "github.com/qdm12/golibs/logging" +) + +type Wave interface { + Add(name string, timeout time.Duration) ( + ctx context.Context, done chan struct{}) + size() int + shutdown(ctx context.Context, logger logging.Logger) (incomplete int) +} + +type wave struct { + name string + routines []routine +} + +func NewWave(name string) Wave { + return &wave{ + name: name, + } +} + +func (w *wave) Add(name string, timeout time.Duration) (ctx context.Context, done chan struct{}) { + ctx, cancel := context.WithCancel(context.Background()) + done = make(chan struct{}) + routine := routine{ + name: name, + cancel: cancel, + done: done, + timeout: timeout, + } + w.routines = append(w.routines, routine) + return ctx, done +} + +func (w *wave) size() int { return len(w.routines) } + +func (w *wave) shutdown(ctx context.Context, logger logging.Logger) (incomplete int) { + completed := make(chan bool) + + for _, r := range w.routines { + go func(r routine) { + if err := r.shutdown(ctx); err != nil { + logger.Warn(w.name + " routines: " + err.Error() + " ⚠️") + completed <- false + } else { + logger.Info(w.name + " routines: " + r.name + " terminated ✔️") + completed <- err == nil + } + }(r) + } + + for range w.routines { + c := <-completed + if !c { + incomplete++ + } + } + + return incomplete +} diff --git a/internal/updater/loop.go b/internal/updater/loop.go index 54bb574f..ab3350cc 100644 --- a/internal/updater/loop.go +++ b/internal/updater/loop.go @@ -14,8 +14,8 @@ import ( ) type Looper interface { - Run(ctx context.Context, wg *sync.WaitGroup) - RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) + Run(ctx context.Context, done chan<- struct{}) + RunRestartTicker(ctx context.Context, done chan<- struct{}) GetStatus() (status models.LoopStatus) SetStatus(status models.LoopStatus) (outcome string, err error) GetSettings() (settings configuration.Updater) @@ -84,15 +84,14 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } } -func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) crashed := false select { case <-l.start: case <-ctx.Done(): return } - defer l.logger.Warn("loop exited") for ctx.Err() == nil { updateCtx, updateCancel := context.WithCancel(ctx) @@ -125,7 +124,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for stayHere { select { case <-ctx.Done(): - l.logger.Warn("context canceled: exiting loop") updateCancel() runWg.Wait() close(errorCh) @@ -162,8 +160,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { } } -func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() +func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) { + defer close(done) timer := time.NewTimer(time.Hour) timer.Stop() timerIsStopped := true