Maintenance: shutdown order
- Order of threads to shutdown (control then tickers then health etc.) - Rely on closing channels instead of waitgroups - Move exit logs from each package to the shutdown package
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
|||||||
nativeos "os"
|
nativeos "os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -29,6 +28,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/routing"
|
"github.com/qdm12/gluetun/internal/routing"
|
||||||
"github.com/qdm12/gluetun/internal/server"
|
"github.com/qdm12/gluetun/internal/server"
|
||||||
"github.com/qdm12/gluetun/internal/shadowsocks"
|
"github.com/qdm12/gluetun/internal/shadowsocks"
|
||||||
|
"github.com/qdm12/gluetun/internal/shutdown"
|
||||||
"github.com/qdm12/gluetun/internal/storage"
|
"github.com/qdm12/gluetun/internal/storage"
|
||||||
"github.com/qdm12/gluetun/internal/unix"
|
"github.com/qdm12/gluetun/internal/unix"
|
||||||
"github.com/qdm12/gluetun/internal/updater"
|
"github.com/qdm12/gluetun/internal/updater"
|
||||||
@@ -255,44 +255,56 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
} // TODO move inside firewall?
|
} // TODO move inside firewall?
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
const (
|
||||||
|
shutdownMaxTimeout = 3 * time.Second
|
||||||
|
shutdownRoutineTimeout = 400 * time.Millisecond
|
||||||
|
shutdownOpenvpnTimeout = time.Second
|
||||||
|
)
|
||||||
|
|
||||||
healthy := make(chan bool)
|
healthy := make(chan bool)
|
||||||
|
controlWave := shutdown.NewWave("control")
|
||||||
|
tickerWave := shutdown.NewWave("tickers")
|
||||||
|
healthWave := shutdown.NewWave("health")
|
||||||
|
dnsWave := shutdown.NewWave("DNS")
|
||||||
|
vpnWave := shutdown.NewWave("VPN")
|
||||||
|
serverWave := shutdown.NewWave("servers")
|
||||||
|
|
||||||
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
||||||
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh, healthy)
|
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh, healthy)
|
||||||
wg.Add(1)
|
openvpnCtx, openvpnDone := vpnWave.Add("openvpn", shutdownOpenvpnTimeout)
|
||||||
// wait for restartOpenvpn
|
// wait for restartOpenvpn
|
||||||
go openvpnLooper.Run(ctx, wg)
|
go openvpnLooper.Run(openvpnCtx, openvpnDone)
|
||||||
|
|
||||||
updaterLooper := updater.NewLooper(allSettings.Updater,
|
updaterLooper := updater.NewLooper(allSettings.Updater,
|
||||||
allServers, storage, openvpnLooper.SetServers, httpClient, logger)
|
allServers, storage, openvpnLooper.SetServers, httpClient, logger)
|
||||||
wg.Add(1)
|
updaterCtx, updaterDone := tickerWave.Add("updater", shutdownRoutineTimeout)
|
||||||
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
|
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
|
||||||
go updaterLooper.Run(ctx, wg)
|
go updaterLooper.Run(updaterCtx, updaterDone)
|
||||||
|
|
||||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient,
|
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient,
|
||||||
logger, nonRootUsername, puid, pgid)
|
logger, nonRootUsername, puid, pgid)
|
||||||
wg.Add(1)
|
dnsCtx, dnsDone := dnsWave.Add("unbound", shutdownRoutineTimeout)
|
||||||
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
|
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
|
||||||
go unboundLooper.Run(ctx, wg)
|
go unboundLooper.Run(dnsCtx, dnsDone)
|
||||||
|
|
||||||
publicIPLooper := publicip.NewLooper(
|
publicIPLooper := publicip.NewLooper(
|
||||||
httpClient, logger, allSettings.PublicIP, puid, pgid, os)
|
httpClient, logger, allSettings.PublicIP, puid, pgid, os)
|
||||||
wg.Add(1)
|
pubIPCtx, pubIPDone := serverWave.Add("public IP", shutdownRoutineTimeout)
|
||||||
go publicIPLooper.Run(ctx, wg)
|
go publicIPLooper.Run(pubIPCtx, pubIPDone)
|
||||||
wg.Add(1)
|
|
||||||
go publicIPLooper.RunRestartTicker(ctx, wg)
|
pubIPTickerCtx, pubIPTickerDone := tickerWave.Add("public IP", shutdownRoutineTimeout)
|
||||||
|
go publicIPLooper.RunRestartTicker(pubIPTickerCtx, pubIPTickerDone)
|
||||||
|
|
||||||
httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy)
|
httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy)
|
||||||
wg.Add(1)
|
httpProxyCtx, httpProxyDone := serverWave.Add("http proxy", shutdownRoutineTimeout)
|
||||||
go httpProxyLooper.Run(ctx, wg)
|
go httpProxyLooper.Run(httpProxyCtx, httpProxyDone)
|
||||||
|
|
||||||
shadowsocksLooper := shadowsocks.NewLooper(allSettings.ShadowSocks, logger)
|
shadowsocksLooper := shadowsocks.NewLooper(allSettings.ShadowSocks, logger)
|
||||||
wg.Add(1)
|
shadowsocksCtx, shadowsocksDone := serverWave.Add("shadowsocks proxy", shutdownRoutineTimeout)
|
||||||
go shadowsocksLooper.Run(ctx, wg)
|
go shadowsocksLooper.Run(shadowsocksCtx, shadowsocksDone)
|
||||||
|
|
||||||
wg.Add(1)
|
eventsRoutingCtx, eventsRoutingDone := controlWave.Add("events routing", shutdownRoutineTimeout)
|
||||||
go routeReadyEvents(ctx, wg, buildInfo, tunnelReadyCh,
|
go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh,
|
||||||
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
|
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
|
||||||
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
|
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
|
||||||
)
|
)
|
||||||
@@ -300,13 +312,17 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
controlServerLogging := allSettings.ControlServer.Log
|
controlServerLogging := allSettings.ControlServer.Log
|
||||||
httpServer := server.New(controlServerAddress, controlServerLogging,
|
httpServer := server.New(controlServerAddress, controlServerLogging,
|
||||||
logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
|
logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
|
||||||
wg.Add(1)
|
httpServerCtx, httpServerDone := controlWave.Add("http server", shutdownRoutineTimeout)
|
||||||
go httpServer.Run(ctx, wg)
|
go httpServer.Run(httpServerCtx, httpServerDone)
|
||||||
|
|
||||||
healthcheckServer := healthcheck.NewServer(
|
healthcheckServer := healthcheck.NewServer(constants.HealthcheckAddress, logger)
|
||||||
constants.HealthcheckAddress, logger)
|
healthServerCtx, healthServerDone := healthWave.Add("HTTP health server", shutdownRoutineTimeout)
|
||||||
wg.Add(1)
|
go healthcheckServer.Run(healthServerCtx, healthy, healthServerDone)
|
||||||
go healthcheckServer.Run(ctx, healthy, wg)
|
|
||||||
|
shutdownOrder := shutdown.NewOrder()
|
||||||
|
shutdownOrder.Append(controlWave, tickerWave, healthWave,
|
||||||
|
dnsWave, vpnWave, serverWave,
|
||||||
|
)
|
||||||
|
|
||||||
// Start openvpn for the first time in a blocking call
|
// Start openvpn for the first time in a blocking call
|
||||||
// until openvpn is launched
|
// until openvpn is launched
|
||||||
@@ -321,6 +337,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := shutdownOrder.Shutdown(shutdownMaxTimeout, logger); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only disable firewall if everything has shutdown gracefully
|
||||||
if allSettings.Firewall.Enabled {
|
if allSettings.Firewall.Enabled {
|
||||||
const enable = false
|
const enable = false
|
||||||
err := firewallConf.SetEnabled(context.Background(), enable)
|
err := firewallConf.SetEnabled(context.Background(), enable)
|
||||||
@@ -329,8 +350,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,22 +368,29 @@ func printVersions(ctx context.Context, logger logging.Logger,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.BuildInformation,
|
func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo models.BuildInformation,
|
||||||
tunnelReadyCh <-chan struct{},
|
tunnelReadyCh <-chan struct{},
|
||||||
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
||||||
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
|
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
|
||||||
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
|
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
tickerWg := &sync.WaitGroup{}
|
|
||||||
// for linters only
|
// for linters only
|
||||||
var restartTickerContext context.Context
|
var restartTickerContext context.Context
|
||||||
var restartTickerCancel context.CancelFunc = func() {}
|
var restartTickerCancel context.CancelFunc = func() {}
|
||||||
|
|
||||||
|
unboundTickerDone := make(chan struct{})
|
||||||
|
close(unboundTickerDone)
|
||||||
|
updaterTickerDone := make(chan struct{})
|
||||||
|
close(updaterTickerDone)
|
||||||
|
|
||||||
first := true
|
first := true
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
restartTickerCancel() // for linters only
|
restartTickerCancel() // for linters only
|
||||||
tickerWg.Wait()
|
<-unboundTickerDone
|
||||||
|
<-updaterTickerDone
|
||||||
return
|
return
|
||||||
case <-tunnelReadyCh: // blocks until openvpn is connected
|
case <-tunnelReadyCh: // blocks until openvpn is connected
|
||||||
vpnDestination, err := routing.VPNDestinationIP()
|
vpnDestination, err := routing.VPNDestinationIP()
|
||||||
@@ -379,7 +405,8 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
|
|||||||
}
|
}
|
||||||
|
|
||||||
restartTickerCancel() // stop previous restart tickers
|
restartTickerCancel() // stop previous restart tickers
|
||||||
tickerWg.Wait()
|
<-unboundTickerDone
|
||||||
|
<-updaterTickerDone
|
||||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
// Runs the Public IP getter job once
|
// Runs the Public IP getter job once
|
||||||
@@ -398,10 +425,10 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gomnd
|
unboundTickerDone = make(chan struct{})
|
||||||
tickerWg.Add(2)
|
updaterTickerDone = make(chan struct{})
|
||||||
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone)
|
||||||
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone)
|
||||||
if portForwardingEnabled {
|
if portForwardingEnabled {
|
||||||
// vpnGateway required only for PIA
|
// vpnGateway required only for PIA
|
||||||
vpnGateway, err := routing.VPNLocalGatewayIP()
|
vpnGateway, err := routing.VPNLocalGatewayIP()
|
||||||
|
|||||||
@@ -3,14 +3,13 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (l *looper) collectLines(wg *sync.WaitGroup, stdout, stderr <-chan string) {
|
func (l *looper) collectLines(stdout, stderr <-chan string, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
var line string
|
var line string
|
||||||
var ok bool
|
var ok bool
|
||||||
for {
|
for {
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetSettings() (settings configuration.DNS)
|
GetSettings() (settings configuration.DNS)
|
||||||
@@ -86,8 +86,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
|
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution?
|
l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution?
|
||||||
@@ -99,8 +99,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer l.logger.Warn("loop exited")
|
|
||||||
|
|
||||||
crashed := false
|
crashed := false
|
||||||
l.backoffTime = defaultBackoffTime
|
l.backoffTime = defaultBackoffTime
|
||||||
|
|
||||||
@@ -116,11 +114,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
if !crashed {
|
if !crashed {
|
||||||
l.running <- constants.Stopped
|
l.running <- constants.Stopped
|
||||||
}
|
}
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, wg, crashed)
|
unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, errUpdateFiles) {
|
if !errors.Is(err, errUpdateFiles) {
|
||||||
const fallback = true
|
const fallback = true
|
||||||
@@ -143,7 +140,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
for stayHere {
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
unboundCancel()
|
unboundCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
@@ -178,9 +174,8 @@ var errUpdateFiles = errors.New("cannot update files")
|
|||||||
// Returning cancel == nil signals we want to re-run setupUnbound
|
// Returning cancel == nil signals we want to re-run setupUnbound
|
||||||
// Returning err == errUpdateFiles signals we should not fall back
|
// Returning err == errUpdateFiles signals we should not fall back
|
||||||
// on the plaintext DNS as DOT is still up and running.
|
// on the plaintext DNS as DOT is still up and running.
|
||||||
func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
|
func (l *looper) setupUnbound(ctx context.Context, previousCrashed bool) (
|
||||||
previousCrashed bool) (cancel context.CancelFunc, waitError chan error,
|
cancel context.CancelFunc, waitError chan error, closeStreams func(), err error) {
|
||||||
closeStreams func(), err error) {
|
|
||||||
err = l.updateFiles(ctx)
|
err = l.updateFiles(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.state.setStatusWithLock(constants.Crashed)
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
@@ -199,8 +194,8 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
collectLinesDone := make(chan struct{})
|
||||||
go l.collectLines(wg, stdoutLines, stderrLines)
|
go l.collectLines(stdoutLines, stderrLines, collectLinesDone)
|
||||||
|
|
||||||
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||||
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
|
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
|
||||||
@@ -216,6 +211,7 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
|
|||||||
close(waitError)
|
close(waitError)
|
||||||
close(stdoutLines)
|
close(stdoutLines)
|
||||||
close(stderrLines)
|
close(stderrLines)
|
||||||
|
<-collectLinesDone
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,6 +226,7 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
|
|||||||
closeStreams = func() {
|
closeStreams = func() {
|
||||||
close(stdoutLines)
|
close(stdoutLines)
|
||||||
close(stderrLines)
|
close(stderrLines)
|
||||||
|
<-collectLinesDone
|
||||||
}
|
}
|
||||||
|
|
||||||
return cancel, waitError, closeStreams, nil
|
return cancel, waitError, closeStreams, nil
|
||||||
@@ -276,8 +273,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
|||||||
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers)
|
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
// Timer that acts as a ticker
|
// Timer that acts as a ticker
|
||||||
timer := time.NewTimer(time.Hour)
|
timer := time.NewTimer(time.Hour)
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
|
|||||||
@@ -6,12 +6,11 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *server) runHealthcheckLoop(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) {
|
func (s *server) runHealthcheckLoop(ctx context.Context, healthy chan<- bool, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
for {
|
for {
|
||||||
previousErr := s.handler.getErr()
|
previousErr := s.handler.getErr()
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup)
|
Run(ctx context.Context, healthy chan<- bool, done chan<- struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
@@ -32,23 +31,20 @@ func NewServer(address string, logger logging.Logger) Server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) {
|
func (s *server) Run(ctx context.Context, healthy chan<- bool, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
|
|
||||||
internalWg := &sync.WaitGroup{}
|
loopDone := make(chan struct{})
|
||||||
internalWg.Add(1)
|
go s.runHealthcheckLoop(ctx, healthy, loopDone)
|
||||||
go s.runHealthcheckLoop(ctx, healthy, internalWg)
|
|
||||||
|
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
Addr: s.address,
|
Addr: s.address,
|
||||||
Handler: s.handler,
|
Handler: s.handler,
|
||||||
}
|
}
|
||||||
internalWg.Add(1)
|
serverDone := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer internalWg.Done()
|
defer close(serverDone)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
s.logger.Warn("context canceled: shutting down server")
|
|
||||||
defer s.logger.Warn("server shut down")
|
|
||||||
const shutdownGraceDuration = 2 * time.Second
|
const shutdownGraceDuration = 2 * time.Second
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -63,5 +59,6 @@ func (s *server) Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGrou
|
|||||||
s.logger.Error(err)
|
s.logger.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
internalWg.Wait()
|
<-loopDone
|
||||||
|
<-serverDone
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
GetSettings() (settings configuration.HTTPProxy)
|
GetSettings() (settings configuration.HTTPProxy)
|
||||||
@@ -50,8 +50,8 @@ func NewLooper(logger logging.Logger, settings configuration.HTTPProxy) Looper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
|
|
||||||
crashed := false
|
crashed := false
|
||||||
|
|
||||||
@@ -67,8 +67,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer l.logger.Warn("loop exited")
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
runCtx, runCancel := context.WithCancel(ctx)
|
runCtx, runCancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
@@ -76,10 +74,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
address := fmt.Sprintf(":%d", settings.Port)
|
address := fmt.Sprintf(":%d", settings.Port)
|
||||||
server := New(runCtx, address, l.logger, settings.Stealth, settings.Log, settings.User, settings.Password)
|
server := New(runCtx, address, l.logger, settings.Stealth, settings.Log, settings.User, settings.Password)
|
||||||
|
|
||||||
runWg := &sync.WaitGroup{}
|
|
||||||
runWg.Add(1)
|
|
||||||
errorCh := make(chan error)
|
errorCh := make(chan error)
|
||||||
go server.Run(runCtx, runWg, errorCh)
|
go server.Run(runCtx, errorCh)
|
||||||
|
|
||||||
// TODO stable timer, check Shadowsocks
|
// TODO stable timer, check Shadowsocks
|
||||||
if !crashed {
|
if !crashed {
|
||||||
@@ -94,22 +90,20 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
for stayHere {
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
runCancel()
|
runCancel()
|
||||||
runWg.Wait()
|
<-errorCh
|
||||||
return
|
return
|
||||||
case <-l.start:
|
case <-l.start:
|
||||||
l.logger.Info("starting")
|
l.logger.Info("starting")
|
||||||
runCancel()
|
runCancel()
|
||||||
runWg.Wait()
|
<-errorCh
|
||||||
stayHere = false
|
stayHere = false
|
||||||
case <-l.stop:
|
case <-l.stop:
|
||||||
l.logger.Info("stopping")
|
l.logger.Info("stopping")
|
||||||
runCancel()
|
runCancel()
|
||||||
runWg.Wait()
|
<-errorCh
|
||||||
l.stopped <- struct{}{}
|
l.stopped <- struct{}{}
|
||||||
case err := <-errorCh:
|
case err := <-errorCh:
|
||||||
runWg.Wait()
|
|
||||||
l.state.setStatusWithLock(constants.Crashed)
|
l.state.setStatusWithLock(constants.Crashed)
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
crashed = true
|
crashed = true
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- error)
|
Run(ctx context.Context, errorCh chan<- error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
@@ -31,13 +31,10 @@ func New(ctx context.Context, address string, logger logging.Logger,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- error) {
|
func (s *server) Run(ctx context.Context, errorCh chan<- error) {
|
||||||
defer wg.Done()
|
|
||||||
server := http.Server{Addr: s.address, Handler: s.handler}
|
server := http.Server{Addr: s.address, Handler: s.handler}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
s.logger.Warn("shutting down server")
|
|
||||||
defer s.logger.Warn("server shut down")
|
|
||||||
const shutdownGraceDuration = 2 * time.Second
|
const shutdownGraceDuration = 2 * time.Second
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -47,8 +44,10 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- err
|
|||||||
}()
|
}()
|
||||||
s.logger.Info("listening on %s", s.address)
|
s.logger.Info("listening on %s", s.address)
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
|
s.internalWG.Wait()
|
||||||
if err != nil && ctx.Err() == nil {
|
if err != nil && ctx.Err() == nil {
|
||||||
errorCh <- err
|
errorCh <- err
|
||||||
|
} else {
|
||||||
|
errorCh <- nil
|
||||||
}
|
}
|
||||||
s.internalWG.Wait()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,14 @@ package openvpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (l *looper) collectLines(wg *sync.WaitGroup, stdout, stderr <-chan string) {
|
func (l *looper) collectLines(stdout, stderr <-chan string, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
var line string
|
var line string
|
||||||
var ok, errLine bool
|
var ok, errLine bool
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetSettings() (settings configuration.OpenVPN)
|
GetSettings() (settings configuration.OpenVPN)
|
||||||
@@ -104,14 +104,13 @@ func (l *looper) signalCrashedStatus() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocognit
|
func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
select {
|
select {
|
||||||
case <-l.start:
|
case <-l.start:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
settings, allServers := l.state.getSettingsAndServers()
|
settings, allServers := l.state.getSettingsAndServers()
|
||||||
@@ -166,20 +165,19 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocogni
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
lineCollectionDone := make(chan struct{})
|
||||||
go l.collectLines(wg, stdoutLines, stderrLines)
|
go l.collectLines(stdoutLines, stderrLines, lineCollectionDone)
|
||||||
|
|
||||||
// Needs the stream line from main.go to know when the tunnel is up
|
// Needs the stream line from main.go to know when the tunnel is up
|
||||||
|
portForwardDone := make(chan struct{})
|
||||||
go func(ctx context.Context) {
|
go func(ctx context.Context) {
|
||||||
for {
|
defer close(portForwardDone)
|
||||||
select {
|
select {
|
||||||
// TODO have a way to disable pf with a context
|
// TODO have a way to disable pf with a context
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case gateway := <-l.portForwardSignals:
|
case gateway := <-l.portForwardSignals:
|
||||||
wg.Add(1)
|
l.portForward(ctx, providerConf, l.client, gateway)
|
||||||
go l.portForward(ctx, wg, providerConf, l.client, gateway)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}(openvpnCtx)
|
}(openvpnCtx)
|
||||||
|
|
||||||
@@ -195,12 +193,13 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocogni
|
|||||||
for stayHere {
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
openvpnCancel()
|
openvpnCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
close(stdoutLines)
|
close(stdoutLines)
|
||||||
close(stderrLines)
|
close(stderrLines)
|
||||||
|
<-lineCollectionDone
|
||||||
|
<-portForwardDone
|
||||||
return
|
return
|
||||||
case <-l.stop:
|
case <-l.stop:
|
||||||
l.logger.Info("stopping")
|
l.logger.Info("stopping")
|
||||||
@@ -288,9 +287,8 @@ func (l *looper) waitForHealth(ctx context.Context) (healthy bool) {
|
|||||||
|
|
||||||
// portForward is a blocking operation which may or may not be infinite.
|
// portForward is a blocking operation which may or may not be infinite.
|
||||||
// You should therefore always call it in a goroutine.
|
// You should therefore always call it in a goroutine.
|
||||||
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
func (l *looper) portForward(ctx context.Context,
|
||||||
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
||||||
defer wg.Done()
|
|
||||||
l.state.portForwardedMu.RLock()
|
l.state.portForwardedMu.RLock()
|
||||||
settings := l.state.settings
|
settings := l.state.settings
|
||||||
l.state.portForwardedMu.RUnlock()
|
l.state.portForwardedMu.RUnlock()
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ var (
|
|||||||
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||||
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||||
syncState func(port uint16) (pfFilepath string)) {
|
syncState func(port uint16) (pfFilepath string)) {
|
||||||
defer logger.Warn("loop exited")
|
|
||||||
|
|
||||||
commonName := p.activeServer.ServerName
|
commonName := p.activeServer.ServerName
|
||||||
if !p.activeServer.PortForward {
|
if !p.activeServer.PortForward {
|
||||||
logger.Error("The server " + commonName +
|
logger.Error("The server " + commonName +
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetSettings() (settings configuration.PublicIP)
|
GetSettings() (settings configuration.PublicIP)
|
||||||
@@ -91,8 +91,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
|
|
||||||
crashed := false
|
crashed := false
|
||||||
|
|
||||||
@@ -101,7 +101,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
getCtx, getCancel := context.WithCancel(ctx)
|
getCtx, getCancel := context.WithCancel(ctx)
|
||||||
@@ -132,11 +131,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
for stayHere {
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
getCancel()
|
getCancel()
|
||||||
close(errorCh)
|
close(errorCh)
|
||||||
filepath := l.GetSettings().IPFilepath
|
filepath := l.GetSettings().IPFilepath
|
||||||
l.logger.Info("Removing ip file %s", filepath)
|
l.logger.Info("Removing ip file " + filepath)
|
||||||
if err := l.os.Remove(filepath); err != nil {
|
if err := l.os.Remove(filepath); err != nil {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
}
|
}
|
||||||
@@ -181,8 +179,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
timer := time.NewTimer(time.Hour)
|
timer := time.NewTimer(time.Hour)
|
||||||
timer.Stop() // 1 hour, cannot be a race condition
|
timer.Stop() // 1 hour, cannot be a race condition
|
||||||
timerIsStopped := true
|
timerIsStopped := true
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/dns"
|
"github.com/qdm12/gluetun/internal/dns"
|
||||||
@@ -16,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
@@ -39,12 +38,11 @@ func New(address string, logEnabled bool, logger logging.Logger,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (s *server) Run(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
server := http.Server{Addr: s.address, Handler: s.handler}
|
server := http.Server{Addr: s.address, Handler: s.handler}
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
s.logger.Warn("context canceled: shutting down")
|
|
||||||
const shutdownGraceDuration = 2 * time.Second
|
const shutdownGraceDuration = 2 * time.Second
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -57,5 +55,4 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
if err != nil && ctx.Err() != context.Canceled {
|
if err != nil && ctx.Err() != context.Canceled {
|
||||||
s.logger.Error(err)
|
s.logger.Error(err)
|
||||||
}
|
}
|
||||||
s.logger.Warn("shut down")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
GetSettings() (settings configuration.ShadowSocks)
|
GetSettings() (settings configuration.ShadowSocks)
|
||||||
@@ -67,8 +67,8 @@ func NewLooper(settings configuration.ShadowSocks, logger logging.Logger) Looper
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
|
|
||||||
crashed := false
|
crashed := false
|
||||||
|
|
||||||
@@ -84,8 +84,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer l.logger.Warn("loop exited")
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
server, err := shadowsockslib.NewServer(settings.Method, settings.Password, adaptLogger(l.logger, settings.Log))
|
server, err := shadowsockslib.NewServer(settings.Method, settings.Password, adaptLogger(l.logger, settings.Log))
|
||||||
@@ -114,7 +112,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
for stayHere {
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
shadowsocksCancel()
|
shadowsocksCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
|
|||||||
50
internal/shutdown/order.go
Normal file
50
internal/shutdown/order.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package shutdown
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Order interface {
|
||||||
|
Append(waves ...Wave)
|
||||||
|
Shutdown(timeout time.Duration, logger logging.Logger) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type order struct {
|
||||||
|
waves []Wave
|
||||||
|
total int // for logging only
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOrder() Order {
|
||||||
|
return &order{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrIncomplete = errors.New("one or more routines did not terminate gracefully")
|
||||||
|
|
||||||
|
func (o *order) Append(waves ...Wave) {
|
||||||
|
o.waves = append(o.waves, waves...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *order) Shutdown(timeout time.Duration, logger logging.Logger) (err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
incomplete := 0
|
||||||
|
|
||||||
|
for _, wave := range o.waves {
|
||||||
|
total += wave.size()
|
||||||
|
incomplete += wave.shutdown(ctx, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if incomplete == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("%w: %d not terminated on %d routines",
|
||||||
|
ErrIncomplete, incomplete, total)
|
||||||
|
}
|
||||||
39
internal/shutdown/routine.go
Normal file
39
internal/shutdown/routine.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package shutdown
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type routine struct {
|
||||||
|
name string
|
||||||
|
cancel context.CancelFunc
|
||||||
|
done <-chan struct{}
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRoutine(name string) (r routine,
|
||||||
|
ctx context.Context, done chan struct{}) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done = make(chan struct{})
|
||||||
|
return routine{
|
||||||
|
name: name,
|
||||||
|
cancel: cancel,
|
||||||
|
done: done,
|
||||||
|
}, ctx, done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routine) shutdown(ctx context.Context) (err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, r.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
r.cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-r.done:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("for routine %q: %w", r.name, ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
66
internal/shutdown/wave.go
Normal file
66
internal/shutdown/wave.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package shutdown
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Wave interface {
|
||||||
|
Add(name string, timeout time.Duration) (
|
||||||
|
ctx context.Context, done chan struct{})
|
||||||
|
size() int
|
||||||
|
shutdown(ctx context.Context, logger logging.Logger) (incomplete int)
|
||||||
|
}
|
||||||
|
|
||||||
|
type wave struct {
|
||||||
|
name string
|
||||||
|
routines []routine
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWave(name string) Wave {
|
||||||
|
return &wave{
|
||||||
|
name: name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wave) Add(name string, timeout time.Duration) (ctx context.Context, done chan struct{}) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done = make(chan struct{})
|
||||||
|
routine := routine{
|
||||||
|
name: name,
|
||||||
|
cancel: cancel,
|
||||||
|
done: done,
|
||||||
|
timeout: timeout,
|
||||||
|
}
|
||||||
|
w.routines = append(w.routines, routine)
|
||||||
|
return ctx, done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wave) size() int { return len(w.routines) }
|
||||||
|
|
||||||
|
func (w *wave) shutdown(ctx context.Context, logger logging.Logger) (incomplete int) {
|
||||||
|
completed := make(chan bool)
|
||||||
|
|
||||||
|
for _, r := range w.routines {
|
||||||
|
go func(r routine) {
|
||||||
|
if err := r.shutdown(ctx); err != nil {
|
||||||
|
logger.Warn(w.name + " routines: " + err.Error() + " ⚠️")
|
||||||
|
completed <- false
|
||||||
|
} else {
|
||||||
|
logger.Info(w.name + " routines: " + r.name + " terminated ✔️")
|
||||||
|
completed <- err == nil
|
||||||
|
}
|
||||||
|
}(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range w.routines {
|
||||||
|
c := <-completed
|
||||||
|
if !c {
|
||||||
|
incomplete++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return incomplete
|
||||||
|
}
|
||||||
@@ -14,8 +14,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, done chan<- struct{})
|
||||||
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
RunRestartTicker(ctx context.Context, done chan<- struct{})
|
||||||
GetStatus() (status models.LoopStatus)
|
GetStatus() (status models.LoopStatus)
|
||||||
SetStatus(status models.LoopStatus) (outcome string, err error)
|
SetStatus(status models.LoopStatus) (outcome string, err error)
|
||||||
GetSettings() (settings configuration.Updater)
|
GetSettings() (settings configuration.Updater)
|
||||||
@@ -84,15 +84,14 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
crashed := false
|
crashed := false
|
||||||
select {
|
select {
|
||||||
case <-l.start:
|
case <-l.start:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
updateCtx, updateCancel := context.WithCancel(ctx)
|
updateCtx, updateCancel := context.WithCancel(ctx)
|
||||||
@@ -125,7 +124,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
for stayHere {
|
for stayHere {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
|
||||||
updateCancel()
|
updateCancel()
|
||||||
runWg.Wait()
|
runWg.Wait()
|
||||||
close(errorCh)
|
close(errorCh)
|
||||||
@@ -162,8 +160,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
|
||||||
defer wg.Done()
|
defer close(done)
|
||||||
timer := time.NewTimer(time.Hour)
|
timer := time.NewTimer(time.Hour)
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
timerIsStopped := true
|
timerIsStopped := true
|
||||||
|
|||||||
Reference in New Issue
Block a user