Maintenance: shutdown order

- Order of threads to shutdown (control then tickers then health etc.)
- Rely on closing channels instead of waitgroups
- Move exit logs from each package to the shutdown package
This commit is contained in:
Quentin McGaw
2021-05-11 22:24:32 +00:00
parent 5159c1dc83
commit cff5e693d2
17 changed files with 292 additions and 140 deletions

View File

@@ -9,7 +9,6 @@ import (
nativeos "os" nativeos "os"
"os/signal" "os/signal"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
@@ -29,6 +28,7 @@ import (
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/server" "github.com/qdm12/gluetun/internal/server"
"github.com/qdm12/gluetun/internal/shadowsocks" "github.com/qdm12/gluetun/internal/shadowsocks"
"github.com/qdm12/gluetun/internal/shutdown"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/unix" "github.com/qdm12/gluetun/internal/unix"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
@@ -255,44 +255,56 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
} // TODO move inside firewall? } // TODO move inside firewall?
wg := &sync.WaitGroup{} const (
shutdownMaxTimeout = 3 * time.Second
shutdownRoutineTimeout = 400 * time.Millisecond
shutdownOpenvpnTimeout = time.Second
)
healthy := make(chan bool) healthy := make(chan bool)
controlWave := shutdown.NewWave("control")
tickerWave := shutdown.NewWave("tickers")
healthWave := shutdown.NewWave("health")
dnsWave := shutdown.NewWave("DNS")
vpnWave := shutdown.NewWave("VPN")
serverWave := shutdown.NewWave("servers")
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh, healthy) ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh, healthy)
wg.Add(1) openvpnCtx, openvpnDone := vpnWave.Add("openvpn", shutdownOpenvpnTimeout)
// wait for restartOpenvpn // wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg) go openvpnLooper.Run(openvpnCtx, openvpnDone)
updaterLooper := updater.NewLooper(allSettings.Updater, updaterLooper := updater.NewLooper(allSettings.Updater,
allServers, storage, openvpnLooper.SetServers, httpClient, logger) allServers, storage, openvpnLooper.SetServers, httpClient, logger)
wg.Add(1) updaterCtx, updaterDone := tickerWave.Add("updater", shutdownRoutineTimeout)
// wait for updaterLooper.Restart() or its ticker launched with RunRestartTicker // wait for updaterLooper.Restart() or its ticker launched with RunRestartTicker
go updaterLooper.Run(ctx, wg) go updaterLooper.Run(updaterCtx, updaterDone)
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient, unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient,
logger, nonRootUsername, puid, pgid) logger, nonRootUsername, puid, pgid)
wg.Add(1) dnsCtx, dnsDone := dnsWave.Add("unbound", shutdownRoutineTimeout)
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg) go unboundLooper.Run(dnsCtx, dnsDone)
publicIPLooper := publicip.NewLooper( publicIPLooper := publicip.NewLooper(
httpClient, logger, allSettings.PublicIP, puid, pgid, os) httpClient, logger, allSettings.PublicIP, puid, pgid, os)
wg.Add(1) pubIPCtx, pubIPDone := serverWave.Add("public IP", shutdownRoutineTimeout)
go publicIPLooper.Run(ctx, wg) go publicIPLooper.Run(pubIPCtx, pubIPDone)
wg.Add(1)
go publicIPLooper.RunRestartTicker(ctx, wg) pubIPTickerCtx, pubIPTickerDone := tickerWave.Add("public IP", shutdownRoutineTimeout)
go publicIPLooper.RunRestartTicker(pubIPTickerCtx, pubIPTickerDone)
httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy) httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy)
wg.Add(1) httpProxyCtx, httpProxyDone := serverWave.Add("http proxy", shutdownRoutineTimeout)
go httpProxyLooper.Run(ctx, wg) go httpProxyLooper.Run(httpProxyCtx, httpProxyDone)
shadowsocksLooper := shadowsocks.NewLooper(allSettings.ShadowSocks, logger) shadowsocksLooper := shadowsocks.NewLooper(allSettings.ShadowSocks, logger)
wg.Add(1) shadowsocksCtx, shadowsocksDone := serverWave.Add("shadowsocks proxy", shutdownRoutineTimeout)
go shadowsocksLooper.Run(ctx, wg) go shadowsocksLooper.Run(shadowsocksCtx, shadowsocksDone)
wg.Add(1) eventsRoutingCtx, eventsRoutingDone := controlWave.Add("events routing", shutdownRoutineTimeout)
go routeReadyEvents(ctx, wg, buildInfo, tunnelReadyCh, go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh,
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
) )
@@ -300,13 +312,17 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
controlServerLogging := allSettings.ControlServer.Log controlServerLogging := allSettings.ControlServer.Log
httpServer := server.New(controlServerAddress, controlServerLogging, httpServer := server.New(controlServerAddress, controlServerLogging,
logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper) logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
wg.Add(1) httpServerCtx, httpServerDone := controlWave.Add("http server", shutdownRoutineTimeout)
go httpServer.Run(ctx, wg) go httpServer.Run(httpServerCtx, httpServerDone)
healthcheckServer := healthcheck.NewServer( healthcheckServer := healthcheck.NewServer(constants.HealthcheckAddress, logger)
constants.HealthcheckAddress, logger) healthServerCtx, healthServerDone := healthWave.Add("HTTP health server", shutdownRoutineTimeout)
wg.Add(1) go healthcheckServer.Run(healthServerCtx, healthy, healthServerDone)
go healthcheckServer.Run(ctx, healthy, wg)
shutdownOrder := shutdown.NewOrder()
shutdownOrder.Append(controlWave, tickerWave, healthWave,
dnsWave, vpnWave, serverWave,
)
// Start openvpn for the first time in a blocking call // Start openvpn for the first time in a blocking call
// until openvpn is launched // until openvpn is launched
@@ -321,6 +337,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
} }
if err := shutdownOrder.Shutdown(shutdownMaxTimeout, logger); err != nil {
return err
}
// Only disable firewall if everything has shutdown gracefully
if allSettings.Firewall.Enabled { if allSettings.Firewall.Enabled {
const enable = false const enable = false
err := firewallConf.SetEnabled(context.Background(), enable) err := firewallConf.SetEnabled(context.Background(), enable)
@@ -329,8 +350,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
} }
wg.Wait()
return nil return nil
} }
@@ -349,22 +368,29 @@ func printVersions(ctx context.Context, logger logging.Logger,
} }
} }
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.BuildInformation, func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo models.BuildInformation,
tunnelReadyCh <-chan struct{}, tunnelReadyCh <-chan struct{},
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
routing routing.Routing, logger logging.Logger, httpClient *http.Client, routing routing.Routing, logger logging.Logger, httpClient *http.Client,
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) { versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
defer wg.Done() defer close(done)
tickerWg := &sync.WaitGroup{}
// for linters only // for linters only
var restartTickerContext context.Context var restartTickerContext context.Context
var restartTickerCancel context.CancelFunc = func() {} var restartTickerCancel context.CancelFunc = func() {}
unboundTickerDone := make(chan struct{})
close(unboundTickerDone)
updaterTickerDone := make(chan struct{})
close(updaterTickerDone)
first := true first := true
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
restartTickerCancel() // for linters only restartTickerCancel() // for linters only
tickerWg.Wait() <-unboundTickerDone
<-updaterTickerDone
return return
case <-tunnelReadyCh: // blocks until openvpn is connected case <-tunnelReadyCh: // blocks until openvpn is connected
vpnDestination, err := routing.VPNDestinationIP() vpnDestination, err := routing.VPNDestinationIP()
@@ -379,7 +405,8 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
} }
restartTickerCancel() // stop previous restart tickers restartTickerCancel() // stop previous restart tickers
tickerWg.Wait() <-unboundTickerDone
<-updaterTickerDone
restartTickerContext, restartTickerCancel = context.WithCancel(ctx) restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
// Runs the Public IP getter job once // Runs the Public IP getter job once
@@ -398,10 +425,10 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
} }
} }
//nolint:gomnd unboundTickerDone = make(chan struct{})
tickerWg.Add(2) updaterTickerDone = make(chan struct{})
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone)
if portForwardingEnabled { if portForwardingEnabled {
// vpnGateway required only for PIA // vpnGateway required only for PIA
vpnGateway, err := routing.VPNLocalGatewayIP() vpnGateway, err := routing.VPNLocalGatewayIP()

View File

@@ -3,14 +3,13 @@ package dns
import ( import (
"regexp" "regexp"
"strings" "strings"
"sync"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func (l *looper) collectLines(wg *sync.WaitGroup, stdout, stderr <-chan string) { func (l *looper) collectLines(stdout, stderr <-chan string, done chan<- struct{}) {
defer wg.Done() defer close(done)
var line string var line string
var ok bool var ok bool
for { for {

View File

@@ -17,8 +17,8 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
GetSettings() (settings configuration.DNS) GetSettings() (settings configuration.DNS)
@@ -86,8 +86,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
const fallback = false const fallback = false
l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution? l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution?
@@ -99,8 +99,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
return return
} }
defer l.logger.Warn("loop exited")
crashed := false crashed := false
l.backoffTime = defaultBackoffTime l.backoffTime = defaultBackoffTime
@@ -116,11 +114,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
if !crashed { if !crashed {
l.running <- constants.Stopped l.running <- constants.Stopped
} }
l.logger.Warn("context canceled: exiting loop")
return return
} }
var err error var err error
unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, wg, crashed) unboundCancel, waitError, closeStreams, err = l.setupUnbound(ctx, crashed)
if err != nil { if err != nil {
if !errors.Is(err, errUpdateFiles) { if !errors.Is(err, errUpdateFiles) {
const fallback = true const fallback = true
@@ -143,7 +140,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
for stayHere { for stayHere {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
unboundCancel() unboundCancel()
<-waitError <-waitError
close(waitError) close(waitError)
@@ -178,9 +174,8 @@ var errUpdateFiles = errors.New("cannot update files")
// Returning cancel == nil signals we want to re-run setupUnbound // Returning cancel == nil signals we want to re-run setupUnbound
// Returning err == errUpdateFiles signals we should not fall back // Returning err == errUpdateFiles signals we should not fall back
// on the plaintext DNS as DOT is still up and running. // on the plaintext DNS as DOT is still up and running.
func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup, func (l *looper) setupUnbound(ctx context.Context, previousCrashed bool) (
previousCrashed bool) (cancel context.CancelFunc, waitError chan error, cancel context.CancelFunc, waitError chan error, closeStreams func(), err error) {
closeStreams func(), err error) {
err = l.updateFiles(ctx) err = l.updateFiles(ctx)
if err != nil { if err != nil {
l.state.setStatusWithLock(constants.Crashed) l.state.setStatusWithLock(constants.Crashed)
@@ -199,8 +194,8 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
return nil, nil, nil, err return nil, nil, nil, err
} }
wg.Add(1) collectLinesDone := make(chan struct{})
go l.collectLines(wg, stdoutLines, stderrLines) go l.collectLines(stdoutLines, stderrLines, collectLinesDone)
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
@@ -216,6 +211,7 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
close(waitError) close(waitError)
close(stdoutLines) close(stdoutLines)
close(stderrLines) close(stderrLines)
<-collectLinesDone
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -230,6 +226,7 @@ func (l *looper) setupUnbound(ctx context.Context, wg *sync.WaitGroup,
closeStreams = func() { closeStreams = func() {
close(stdoutLines) close(stdoutLines)
close(stderrLines) close(stderrLines)
<-collectLinesDone
} }
return cancel, waitError, closeStreams, nil return cancel, waitError, closeStreams, nil
@@ -276,8 +273,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers) l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers)
} }
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
// Timer that acts as a ticker // Timer that acts as a ticker
timer := time.NewTimer(time.Hour) timer := time.NewTimer(time.Hour)
timer.Stop() timer.Stop()

View File

@@ -6,12 +6,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
) )
func (s *server) runHealthcheckLoop(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) { func (s *server) runHealthcheckLoop(ctx context.Context, healthy chan<- bool, done chan<- struct{}) {
defer wg.Done() defer close(done)
for { for {
previousErr := s.handler.getErr() previousErr := s.handler.getErr()

View File

@@ -5,14 +5,13 @@ import (
"errors" "errors"
"net" "net"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
type Server interface { type Server interface {
Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) Run(ctx context.Context, healthy chan<- bool, done chan<- struct{})
} }
type server struct { type server struct {
@@ -32,23 +31,20 @@ func NewServer(address string, logger logging.Logger) Server {
} }
} }
func (s *server) Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGroup) { func (s *server) Run(ctx context.Context, healthy chan<- bool, done chan<- struct{}) {
defer wg.Done() defer close(done)
internalWg := &sync.WaitGroup{} loopDone := make(chan struct{})
internalWg.Add(1) go s.runHealthcheckLoop(ctx, healthy, loopDone)
go s.runHealthcheckLoop(ctx, healthy, internalWg)
server := http.Server{ server := http.Server{
Addr: s.address, Addr: s.address,
Handler: s.handler, Handler: s.handler,
} }
internalWg.Add(1) serverDone := make(chan struct{})
go func() { go func() {
defer internalWg.Done() defer close(serverDone)
<-ctx.Done() <-ctx.Done()
s.logger.Warn("context canceled: shutting down server")
defer s.logger.Warn("server shut down")
const shutdownGraceDuration = 2 * time.Second const shutdownGraceDuration = 2 * time.Second
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
defer cancel() defer cancel()
@@ -63,5 +59,6 @@ func (s *server) Run(ctx context.Context, healthy chan<- bool, wg *sync.WaitGrou
s.logger.Error(err) s.logger.Error(err)
} }
internalWg.Wait() <-loopDone
<-serverDone
} }

View File

@@ -14,7 +14,7 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
GetSettings() (settings configuration.HTTPProxy) GetSettings() (settings configuration.HTTPProxy)
@@ -50,8 +50,8 @@ func NewLooper(logger logging.Logger, settings configuration.HTTPProxy) Looper {
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
crashed := false crashed := false
@@ -67,8 +67,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
return return
} }
defer l.logger.Warn("loop exited")
for ctx.Err() == nil { for ctx.Err() == nil {
runCtx, runCancel := context.WithCancel(ctx) runCtx, runCancel := context.WithCancel(ctx)
@@ -76,10 +74,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
address := fmt.Sprintf(":%d", settings.Port) address := fmt.Sprintf(":%d", settings.Port)
server := New(runCtx, address, l.logger, settings.Stealth, settings.Log, settings.User, settings.Password) server := New(runCtx, address, l.logger, settings.Stealth, settings.Log, settings.User, settings.Password)
runWg := &sync.WaitGroup{}
runWg.Add(1)
errorCh := make(chan error) errorCh := make(chan error)
go server.Run(runCtx, runWg, errorCh) go server.Run(runCtx, errorCh)
// TODO stable timer, check Shadowsocks // TODO stable timer, check Shadowsocks
if !crashed { if !crashed {
@@ -94,22 +90,20 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
for stayHere { for stayHere {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
runCancel() runCancel()
runWg.Wait() <-errorCh
return return
case <-l.start: case <-l.start:
l.logger.Info("starting") l.logger.Info("starting")
runCancel() runCancel()
runWg.Wait() <-errorCh
stayHere = false stayHere = false
case <-l.stop: case <-l.stop:
l.logger.Info("stopping") l.logger.Info("stopping")
runCancel() runCancel()
runWg.Wait() <-errorCh
l.stopped <- struct{}{} l.stopped <- struct{}{}
case err := <-errorCh: case err := <-errorCh:
runWg.Wait()
l.state.setStatusWithLock(constants.Crashed) l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
crashed = true crashed = true

View File

@@ -10,7 +10,7 @@ import (
) )
type Server interface { type Server interface {
Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- error) Run(ctx context.Context, errorCh chan<- error)
} }
type server struct { type server struct {
@@ -31,13 +31,10 @@ func New(ctx context.Context, address string, logger logging.Logger,
} }
} }
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- error) { func (s *server) Run(ctx context.Context, errorCh chan<- error) {
defer wg.Done()
server := http.Server{Addr: s.address, Handler: s.handler} server := http.Server{Addr: s.address, Handler: s.handler}
go func() { go func() {
<-ctx.Done() <-ctx.Done()
s.logger.Warn("shutting down server")
defer s.logger.Warn("server shut down")
const shutdownGraceDuration = 2 * time.Second const shutdownGraceDuration = 2 * time.Second
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
defer cancel() defer cancel()
@@ -47,8 +44,10 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup, errorCh chan<- err
}() }()
s.logger.Info("listening on %s", s.address) s.logger.Info("listening on %s", s.address)
err := server.ListenAndServe() err := server.ListenAndServe()
s.internalWG.Wait()
if err != nil && ctx.Err() == nil { if err != nil && ctx.Err() == nil {
errorCh <- err errorCh <- err
} else {
errorCh <- nil
} }
s.internalWG.Wait()
} }

View File

@@ -2,15 +2,14 @@ package openvpn
import ( import (
"strings" "strings"
"sync"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func (l *looper) collectLines(wg *sync.WaitGroup, stdout, stderr <-chan string) { func (l *looper) collectLines(stdout, stderr <-chan string, done chan<- struct{}) {
defer wg.Done() defer close(done)
var line string var line string
var ok, errLine bool var ok, errLine bool

View File

@@ -19,7 +19,7 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
GetSettings() (settings configuration.OpenVPN) GetSettings() (settings configuration.OpenVPN)
@@ -104,14 +104,13 @@ func (l *looper) signalCrashedStatus() {
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocognit func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit
defer wg.Done() defer close(done)
select { select {
case <-l.start: case <-l.start:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited")
for ctx.Err() == nil { for ctx.Err() == nil {
settings, allServers := l.state.getSettingsAndServers() settings, allServers := l.state.getSettingsAndServers()
@@ -166,20 +165,19 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocogni
continue continue
} }
wg.Add(1) lineCollectionDone := make(chan struct{})
go l.collectLines(wg, stdoutLines, stderrLines) go l.collectLines(stdoutLines, stderrLines, lineCollectionDone)
// Needs the stream line from main.go to know when the tunnel is up // Needs the stream line from main.go to know when the tunnel is up
portForwardDone := make(chan struct{})
go func(ctx context.Context) { go func(ctx context.Context) {
for { defer close(portForwardDone)
select { select {
// TODO have a way to disable pf with a context // TODO have a way to disable pf with a context
case <-ctx.Done(): case <-ctx.Done():
return return
case gateway := <-l.portForwardSignals: case gateway := <-l.portForwardSignals:
wg.Add(1) l.portForward(ctx, providerConf, l.client, gateway)
go l.portForward(ctx, wg, providerConf, l.client, gateway)
}
} }
}(openvpnCtx) }(openvpnCtx)
@@ -195,12 +193,13 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { //nolint:gocogni
for stayHere { for stayHere {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
openvpnCancel() openvpnCancel()
<-waitError <-waitError
close(waitError) close(waitError)
close(stdoutLines) close(stdoutLines)
close(stderrLines) close(stderrLines)
<-lineCollectionDone
<-portForwardDone
return return
case <-l.stop: case <-l.stop:
l.logger.Info("stopping") l.logger.Info("stopping")
@@ -288,9 +287,8 @@ func (l *looper) waitForHealth(ctx context.Context) (healthy bool) {
// portForward is a blocking operation which may or may not be infinite. // portForward is a blocking operation which may or may not be infinite.
// You should therefore always call it in a goroutine. // You should therefore always call it in a goroutine.
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup, func (l *looper) portForward(ctx context.Context,
providerConf provider.Provider, client *http.Client, gateway net.IP) { providerConf provider.Provider, client *http.Client, gateway net.IP) {
defer wg.Done()
l.state.portForwardedMu.RLock() l.state.portForwardedMu.RLock()
settings := l.state.settings settings := l.state.settings
l.state.portForwardedMu.RUnlock() l.state.portForwardedMu.RUnlock()

View File

@@ -29,8 +29,6 @@ var (
func (p *PIA) PortForward(ctx context.Context, client *http.Client, func (p *PIA) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
defer logger.Warn("loop exited")
commonName := p.activeServer.ServerName commonName := p.activeServer.ServerName
if !p.activeServer.PortForward { if !p.activeServer.PortForward {
logger.Error("The server " + commonName + logger.Error("The server " + commonName +

View File

@@ -15,8 +15,8 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
GetSettings() (settings configuration.PublicIP) GetSettings() (settings configuration.PublicIP)
@@ -91,8 +91,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
crashed := false crashed := false
@@ -101,7 +101,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited")
for ctx.Err() == nil { for ctx.Err() == nil {
getCtx, getCancel := context.WithCancel(ctx) getCtx, getCancel := context.WithCancel(ctx)
@@ -132,11 +131,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
for stayHere { for stayHere {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
getCancel() getCancel()
close(errorCh) close(errorCh)
filepath := l.GetSettings().IPFilepath filepath := l.GetSettings().IPFilepath
l.logger.Info("Removing ip file %s", filepath) l.logger.Info("Removing ip file " + filepath)
if err := l.os.Remove(filepath); err != nil { if err := l.os.Remove(filepath); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
@@ -181,8 +179,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
} }
} }
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
timer := time.NewTimer(time.Hour) timer := time.NewTimer(time.Hour)
timer.Stop() // 1 hour, cannot be a race condition timer.Stop() // 1 hour, cannot be a race condition
timerIsStopped := true timerIsStopped := true

View File

@@ -4,7 +4,6 @@ package server
import ( import (
"context" "context"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/qdm12/gluetun/internal/dns" "github.com/qdm12/gluetun/internal/dns"
@@ -16,7 +15,7 @@ import (
) )
type Server interface { type Server interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
} }
type server struct { type server struct {
@@ -39,12 +38,11 @@ func New(address string, logEnabled bool, logger logging.Logger,
} }
} }
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { func (s *server) Run(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
server := http.Server{Addr: s.address, Handler: s.handler} server := http.Server{Addr: s.address, Handler: s.handler}
go func() { go func() {
<-ctx.Done() <-ctx.Done()
s.logger.Warn("context canceled: shutting down")
const shutdownGraceDuration = 2 * time.Second const shutdownGraceDuration = 2 * time.Second
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration)
defer cancel() defer cancel()
@@ -57,5 +55,4 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
if err != nil && ctx.Err() != context.Canceled { if err != nil && ctx.Err() != context.Canceled {
s.logger.Error(err) s.logger.Error(err)
} }
s.logger.Warn("shut down")
} }

View File

@@ -15,7 +15,7 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
GetSettings() (settings configuration.ShadowSocks) GetSettings() (settings configuration.ShadowSocks)
@@ -67,8 +67,8 @@ func NewLooper(settings configuration.ShadowSocks, logger logging.Logger) Looper
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
crashed := false crashed := false
@@ -84,8 +84,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
return return
} }
defer l.logger.Warn("loop exited")
for ctx.Err() == nil { for ctx.Err() == nil {
settings := l.GetSettings() settings := l.GetSettings()
server, err := shadowsockslib.NewServer(settings.Method, settings.Password, adaptLogger(l.logger, settings.Log)) server, err := shadowsockslib.NewServer(settings.Method, settings.Password, adaptLogger(l.logger, settings.Log))
@@ -114,7 +112,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
for stayHere { for stayHere {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
shadowsocksCancel() shadowsocksCancel()
<-waitError <-waitError
close(waitError) close(waitError)

View File

@@ -0,0 +1,50 @@
package shutdown
import (
"context"
"errors"
"fmt"
"time"
"github.com/qdm12/golibs/logging"
)
// Order shuts down groups (waves) of routines sequentially,
// in the order in which the waves were appended.
type Order interface {
	// Append adds waves at the end of the shutdown order.
	Append(waves ...Wave)
	// Shutdown shuts down every wave in order, sharing the given
	// total timeout across all waves, and logs progress using the
	// given logger. It returns an error wrapping ErrIncomplete if
	// one or more routines did not terminate in time.
	Shutdown(timeout time.Duration, logger logging.Logger) (err error)
}
// order is the Order implementation; it holds the waves to shut
// down, in the order in which they will be shut down.
// Note: a former `total` field was removed as dead code — Shutdown
// computes its routine total in a local variable and never read
// or wrote the field.
type order struct {
	waves []Wave
}
// NewOrder returns an empty shutdown Order, ready for waves
// to be appended to it.
func NewOrder() Order {
	return new(order)
}

// ErrIncomplete is wrapped in the error returned by Shutdown when
// at least one routine did not terminate before its deadline.
var ErrIncomplete = errors.New("one or more routines did not terminate gracefully")
// Append adds the given waves at the end of the shutdown order.
func (o *order) Append(waves ...Wave) {
	for _, wave := range waves {
		o.waves = append(o.waves, wave)
	}
}
// Shutdown shuts down each wave sequentially, in the order the waves
// were appended, sharing a single context bounded by timeout across
// all waves. Progress is logged with the given logger. It returns nil
// when every routine terminated, and otherwise an error wrapping
// ErrIncomplete stating how many routines did not terminate.
func (o *order) Shutdown(timeout time.Duration, logger logging.Logger) (err error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	total := 0      // routines seen across all waves, for the error message
	incomplete := 0 // routines that did not terminate in time
	for _, wave := range o.waves {
		total += wave.size()
		incomplete += wave.shutdown(ctx, logger)
	}

	if incomplete == 0 {
		return nil
	}
	// Fixed message grammar: was "%d not terminated on %d routines".
	return fmt.Errorf("%w: %d of %d routines did not terminate",
		ErrIncomplete, incomplete, total)
}

View File

@@ -0,0 +1,39 @@
package shutdown
import (
"context"
"fmt"
"time"
)
// routine tracks one background goroutine registered for shutdown:
// cancel asks it to stop, done signals that it has fully stopped,
// and timeout bounds how long shutdown waits on done.
type routine struct {
	name    string
	cancel  context.CancelFunc
	done    <-chan struct{}
	timeout time.Duration
}

// newRoutine creates a routine together with its own cancelable
// context and done channel. The caller runs its goroutine with ctx
// and closes done once the goroutine has exited.
// NOTE(review): the timeout field is left at its zero value here, so
// shutdown on a routine built this way would hit its deadline almost
// immediately; wave.Add constructs routines directly with an explicit
// timeout instead — confirm whether this helper is still needed.
func newRoutine(name string) (r routine,
	ctx context.Context, done chan struct{}) {
	ctx, cancel := context.WithCancel(context.Background())
	done = make(chan struct{})
	r = routine{
		name:   name,
		cancel: cancel,
		done:   done,
	}
	return r, ctx, done
}

// shutdown cancels the routine and waits up to r.timeout (further
// capped by the parent ctx) for its done channel to close.
// It returns nil on graceful termination, or an error naming the
// routine when the wait times out.
func (r *routine) shutdown(ctx context.Context) (err error) {
	waitCtx, cancelWait := context.WithTimeout(ctx, r.timeout)
	defer cancelWait()

	r.cancel()
	select {
	case <-waitCtx.Done():
		return fmt.Errorf("for routine %q: %w", r.name, waitCtx.Err())
	case <-r.done:
		return nil
	}
}

66
internal/shutdown/wave.go Normal file
View File

@@ -0,0 +1,66 @@
package shutdown
import (
"context"
"time"
"github.com/qdm12/golibs/logging"
)
// Wave is a named group of routines that are shut down concurrently
// as a single unit. Its unexported methods restrict implementations
// to this package; the shutdown Order drives them.
type Wave interface {
	// Add registers a routine under name and returns the context the
	// routine must watch for cancellation, plus the done channel it
	// must close once it has fully stopped. timeout bounds how long
	// shutdown waits for done to be closed.
	Add(name string, timeout time.Duration) (
		ctx context.Context, done chan struct{})
	// size returns the number of routines registered in the wave.
	size() int
	// shutdown cancels all routines concurrently, waits on each one,
	// logs every outcome, and returns how many did not terminate.
	shutdown(ctx context.Context, logger logging.Logger) (incomplete int)
}
// wave is the Wave implementation: a named collection of routines
// that are shut down concurrently when the wave is shut down.
type wave struct {
	name     string
	routines []routine
}
// NewWave returns a Wave with the given name and no routines.
func NewWave(name string) Wave {
	w := wave{name: name}
	return &w
}
// Add registers a routine under the given name and returns the
// context the routine must watch for cancellation, together with the
// done channel it must close once fully stopped. The timeout bounds
// how long shutdown later waits for the done channel.
func (w *wave) Add(name string, timeout time.Duration) (ctx context.Context, done chan struct{}) {
	ctx, cancel := context.WithCancel(context.Background())
	done = make(chan struct{})
	w.routines = append(w.routines, routine{
		name:    name,
		cancel:  cancel,
		done:    done,
		timeout: timeout,
	})
	return ctx, done
}
// size returns the number of routines registered in the wave.
func (w *wave) size() int { return len(w.routines) }
// shutdown shuts all the wave's routines down concurrently, waits for
// each of them (each bounded by its own timeout and the parent ctx),
// logs every outcome, and returns the number of routines that did not
// terminate in time.
func (w *wave) shutdown(ctx context.Context, logger logging.Logger) (incomplete int) {
	// Buffered so no goroutine can block on send even if the receive
	// loop is behind; one result per routine.
	results := make(chan bool, len(w.routines))
	for _, r := range w.routines {
		// r is passed as an argument to avoid loop-variable capture.
		go func(r routine) {
			err := r.shutdown(ctx)
			if err != nil {
				logger.Warn(w.name + " routines: " + err.Error() + " ⚠️")
				results <- false
				return
			}
			logger.Info(w.name + " routines: " + r.name + " terminated ✔️")
			// err is known nil here; the former `results <- err == nil`
			// was a redundant way to send true.
			results <- true
		}(r)
	}
	for range w.routines {
		if terminated := <-results; !terminated {
			incomplete++
		}
	}
	return incomplete
}

View File

@@ -14,8 +14,8 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, done chan<- struct{})
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, done chan<- struct{})
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
GetSettings() (settings configuration.Updater) GetSettings() (settings configuration.Updater)
@@ -84,15 +84,14 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
crashed := false crashed := false
select { select {
case <-l.start: case <-l.start:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited")
for ctx.Err() == nil { for ctx.Err() == nil {
updateCtx, updateCancel := context.WithCancel(ctx) updateCtx, updateCancel := context.WithCancel(ctx)
@@ -125,7 +124,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
for stayHere { for stayHere {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
updateCancel() updateCancel()
runWg.Wait() runWg.Wait()
close(errorCh) close(errorCh)
@@ -162,8 +160,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
} }
} }
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) RunRestartTicker(ctx context.Context, done chan<- struct{}) {
defer wg.Done() defer close(done)
timer := time.NewTimer(time.Hour) timer := time.NewTimer(time.Hour)
timer.Stop() timer.Stop()
timerIsStopped := true timerIsStopped := true