diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index f8bf4fab..3dac1101 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -29,7 +29,6 @@ import ( "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/server" "github.com/qdm12/gluetun/internal/shadowsocks" - "github.com/qdm12/gluetun/internal/shutdown" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/unix" "github.com/qdm12/gluetun/internal/updater" @@ -38,6 +37,7 @@ import ( "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/os/user" "github.com/qdm12/golibs/params" + "github.com/qdm12/goshutdown" "github.com/qdm12/updated/pkg/dnscrypto" ) @@ -275,82 +275,113 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } } // TODO move inside firewall? - const ( - shutdownMaxTimeout = 3 * time.Second - shutdownRoutineTimeout = 400 * time.Millisecond - shutdownOpenvpnTimeout = time.Second - ) - healthy := make(chan bool) - controlWave := shutdown.NewWave("control") - tickerWave := shutdown.NewWave("tickers") - healthWave := shutdown.NewWave("health") - dnsWave := shutdown.NewWave("DNS") - vpnWave := shutdown.NewWave("VPN") - serverWave := shutdown.NewWave("servers") + + // Shutdown settings + const defaultShutdownTimeout = 400 * time.Millisecond + defaultShutdownOnSuccess := func(goRoutineName string) { + logger.Info(goRoutineName + ": terminated ✔️") + } + defaultShutdownOnFailure := func(goRoutineName string, err error) { + logger.Warn(goRoutineName + ": " + err.Error() + " ⚠️") + } + defaultGoRoutineSettings := goshutdown.GoRoutineSettings{Timeout: defaultShutdownTimeout} + defaultGroupSettings := goshutdown.GroupSettings{ + Timeout: defaultShutdownTimeout, + OnFailure: defaultShutdownOnFailure, + OnSuccess: defaultShutdownOnSuccess, + } + + controlGroupHandler := goshutdown.NewGroupHandler("control", defaultGroupSettings) + tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupSettings) + otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings) openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh, healthy) - openvpnCtx, openvpnDone := vpnWave.Add("openvpn", shutdownOpenvpnTimeout) + openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( + "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) // wait for restartOpenvpn go openvpnLooper.Run(openvpnCtx, openvpnDone) updaterLooper := updater.NewLooper(allSettings.Updater, allServers, storage, openvpnLooper.SetServers, httpClient, logger.NewChild(logging.Settings{Prefix: "updater: "})) - updaterCtx, updaterDone := tickerWave.Add("updater", shutdownRoutineTimeout) + updaterHandler, updaterCtx, updaterDone := goshutdown.NewGoRoutineHandler( + "updater", defaultGoRoutineSettings) // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker go updaterLooper.Run(updaterCtx, updaterDone) + tickersGroupHandler.Add(updaterHandler) unboundLogger := logger.NewChild(logging.Settings{Prefix: "dns over tls: "}) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient, unboundLogger, os.OpenFile) - dnsCtx, dnsDone := dnsWave.Add("unbound", shutdownRoutineTimeout) + dnsHandler, dnsCtx, dnsDone := goshutdown.NewGoRoutineHandler( + "unbound", defaultGoRoutineSettings) // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker go unboundLooper.Run(dnsCtx, dnsDone) + otherGroupHandler.Add(dnsHandler) publicIPLooper := publicip.NewLooper(httpClient, logger.NewChild(logging.Settings{Prefix: "ip getter: "}), allSettings.PublicIP, puid, pgid, os) - pubIPCtx, pubIPDone := serverWave.Add("public IP", shutdownRoutineTimeout) + pubIPHandler, pubIPCtx, pubIPDone := goshutdown.NewGoRoutineHandler( + "public IP", defaultGoRoutineSettings) go publicIPLooper.Run(pubIPCtx, pubIPDone) + otherGroupHandler.Add(pubIPHandler) - pubIPTickerCtx, pubIPTickerDone := tickerWave.Add("public IP", shutdownRoutineTimeout) + pubIPTickerHandler, pubIPTickerCtx, pubIPTickerDone := goshutdown.NewGoRoutineHandler( + "public IP", defaultGoRoutineSettings) go publicIPLooper.RunRestartTicker(pubIPTickerCtx, pubIPTickerDone) + tickersGroupHandler.Add(pubIPTickerHandler) httpProxyLooper := httpproxy.NewLooper( logger.NewChild(logging.Settings{Prefix: "http proxy: "}), allSettings.HTTPProxy) - httpProxyCtx, httpProxyDone := serverWave.Add("http proxy", shutdownRoutineTimeout) + httpProxyHandler, httpProxyCtx, httpProxyDone := goshutdown.NewGoRoutineHandler( + "http proxy", defaultGoRoutineSettings) go httpProxyLooper.Run(httpProxyCtx, httpProxyDone) + otherGroupHandler.Add(httpProxyHandler) shadowsocksLooper := shadowsocks.NewLooper(allSettings.ShadowSocks, logger.NewChild(logging.Settings{Prefix: "shadowsocks: "})) - shadowsocksCtx, shadowsocksDone := serverWave.Add("shadowsocks proxy", shutdownRoutineTimeout) + shadowsocksHandler, shadowsocksCtx, shadowsocksDone := goshutdown.NewGoRoutineHandler( + "shadowsocks proxy", defaultGoRoutineSettings) go shadowsocksLooper.Run(shadowsocksCtx, shadowsocksDone) + otherGroupHandler.Add(shadowsocksHandler) - eventsRoutingCtx, eventsRoutingDone := controlWave.Add("events routing", shutdownRoutineTimeout) + eventsRoutingHandler, eventsRoutingCtx, eventsRoutingDone := goshutdown.NewGoRoutineHandler( + "events routing", defaultGoRoutineSettings) go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh, unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, ) + controlGroupHandler.Add(eventsRoutingHandler) + controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port)) controlServerLogging := allSettings.ControlServer.Log httpServer := server.New(controlServerAddress, controlServerLogging, logger.NewChild(logging.Settings{Prefix: "http server: "}), buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) - httpServerCtx, httpServerDone := controlWave.Add("http server", shutdownRoutineTimeout) + httpServerHandler, httpServerCtx, httpServerDone := goshutdown.NewGoRoutineHandler( + "http server", defaultGoRoutineSettings) go httpServer.Run(httpServerCtx, httpServerDone) + controlGroupHandler.Add(httpServerHandler) healthcheckServer := healthcheck.NewServer(constants.HealthcheckAddress, logger.NewChild(logging.Settings{Prefix: "healthcheck: "})) - healthServerCtx, healthServerDone := healthWave.Add("HTTP health server", shutdownRoutineTimeout) + healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler( + "HTTP health server", defaultGoRoutineSettings) go healthcheckServer.Run(healthServerCtx, healthy, healthServerDone) - shutdownOrder := shutdown.NewOrder() - shutdownOrder.Append(controlWave, tickerWave, healthWave, - dnsWave, vpnWave, serverWave, - ) + const orderShutdownTimeout = 3 * time.Second + orderSettings := goshutdown.OrderSettings{ + Timeout: orderShutdownTimeout, + OnFailure: defaultShutdownOnFailure, + OnSuccess: defaultShutdownOnSuccess, + } + orderHandler := goshutdown.NewOrder("gluetun", orderSettings) + orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler, + openvpnHandler, otherGroupHandler) // Start openvpn for the first time in a blocking call // until openvpn is launched @@ -365,7 +396,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } } - return shutdownOrder.Shutdown(shutdownMaxTimeout, logger) + return orderHandler.Shutdown(context.Background()) } type printVersionElement struct { diff --git a/go.mod b/go.mod index 14ced5ef..77b1c921 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/golang/mock v1.5.0 github.com/qdm12/dns v1.8.0 github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb + github.com/qdm12/goshutdown v0.1.0 github.com/qdm12/ss-server v0.2.0 github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 58fdf25b..cbf6c364 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/qdm12/dns v1.8.0 h1:GZ40kptmfDHOMNxBKWSA4zrbNyGm41BA57zv2MaDtCI= github.com/qdm12/dns v1.8.0/go.mod h1:P2mm63NDYZdx2NAd5CVLM0FBnNdi1ZgVjsRSnX+96vg= github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb h1:5WkOssTWl6Tv2H7VFb2jwB08A7BxxNCebkkpvz1PzrY= github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg= +github.com/qdm12/goshutdown v0.1.0 h1:lmwnygdXtnr2pa6VqfR/bm8077/BnBef1+7CP96B7Sw= +github.com/qdm12/goshutdown v0.1.0/go.mod h1:/LP3MWLqI+wGH/ijfaUG+RHzBbKXIiVKnrg5vXOCf6Q= github.com/qdm12/ss-server v0.2.0 h1:+togLzeeLAJ68MD1JqOWvYi9rl9t/fx1Qh7wKzZhY1g= github.com/qdm12/ss-server v0.2.0/go.mod h1:+1bWO1EfWNvsGM5Cuep6vneChK2OHniqtAsED9Fh1y0= github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e h1:4q+uFLawkaQRq3yARYLsjJPZd2wYwxn4g6G/5v0xW1g= diff --git a/internal/shutdown/order.go b/internal/shutdown/order.go deleted file mode 100644 index 755ee7b2..00000000 --- a/internal/shutdown/order.go +++ /dev/null @@ -1,49 +0,0 @@ -package shutdown - -import ( - "context" - "errors" - "fmt" - "time" - - "github.com/qdm12/golibs/logging" -) - -type Order interface { - Append(waves ...Wave) - Shutdown(timeout time.Duration, logger logging.Logger) (err error) -} - -type order struct { - waves []Wave -} - -func NewOrder() Order { - return &order{} -} - -var ErrIncomplete = errors.New("one or more routines did not terminate gracefully") - -func (o *order) Append(waves ...Wave) { - o.waves = append(o.waves, waves...) -} - -func (o *order) Shutdown(timeout time.Duration, logger logging.Logger) (err error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - total := 0 - incomplete := 0 - - for _, wave := range o.waves { - total += wave.size() - incomplete += wave.shutdown(ctx, logger) - } - - if incomplete == 0 { - return nil - } - - return fmt.Errorf("%w: %d not terminated on %d routines", - ErrIncomplete, incomplete, total) -} diff --git a/internal/shutdown/routine.go b/internal/shutdown/routine.go deleted file mode 100644 index 110dfa53..00000000 --- a/internal/shutdown/routine.go +++ /dev/null @@ -1,28 +0,0 @@ -package shutdown - -import ( - "context" - "fmt" - "time" -) - -type routine struct { - name string - cancel context.CancelFunc - done <-chan struct{} - timeout time.Duration -} - -func (r *routine) shutdown(ctx context.Context) (err error) { - ctx, cancel := context.WithTimeout(ctx, r.timeout) - defer cancel() - - r.cancel() - - select { - case <-r.done: - return nil - case <-ctx.Done(): - return fmt.Errorf("for routine %q: %w", r.name, ctx.Err()) - } -} diff --git a/internal/shutdown/wave.go b/internal/shutdown/wave.go deleted file mode 100644 index 2be3a8b9..00000000 --- a/internal/shutdown/wave.go +++ /dev/null @@ -1,66 +0,0 @@ -package shutdown - -import ( - "context" - "time" - - "github.com/qdm12/golibs/logging" -) - -type Wave interface { - Add(name string, timeout time.Duration) ( - ctx context.Context, done chan struct{}) - size() int - shutdown(ctx context.Context, logger logging.Logger) (incomplete int) -} - -type wave struct { - name string - routines []routine -} - -func NewWave(name string) Wave { - return &wave{ - name: name, - } -} - -func (w *wave) Add(name string, timeout time.Duration) (ctx context.Context, done chan struct{}) { - ctx, cancel := context.WithCancel(context.Background()) - done = make(chan struct{}) - routine := routine{ - name: name, - cancel: cancel, - done: done, - timeout: timeout, - } - w.routines = append(w.routines, routine) - return ctx, done -} - -func (w *wave) size() int { return len(w.routines) } - -func (w *wave) shutdown(ctx context.Context, logger logging.Logger) (incomplete int) { - completed := make(chan bool) - - for _, r := range w.routines { - go func(r routine) { - if err := r.shutdown(ctx); err != nil { - logger.Warn(w.name + " routines: " + err.Error() + " ⚠️") - completed <- false - } else { - logger.Info(w.name + " routines: " + r.name + " terminated ✔️") - completed <- err == nil - } - }(r) - } - - for range w.routines { - c := <-completed - if !c { - incomplete++ - } - } - - return incomplete -}