Revisit waitgroup (#241)

* Fix Add to waitgroup out of goroutines calling wg.Done()
* Pass waitgroup to other loop functions
This commit is contained in:
Quentin McGaw
2020-09-12 14:34:15 -04:00
committed by GitHub
parent 1c012e4c92
commit e0e450ca1c
8 changed files with 33 additions and 20 deletions

View File

@@ -194,6 +194,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
portForward := openvpnLooper.PortForward portForward := openvpnLooper.PortForward
getOpenvpnSettings := openvpnLooper.GetSettings getOpenvpnSettings := openvpnLooper.GetSettings
getPortForwarded := openvpnLooper.GetPortForwarded getPortForwarded := openvpnLooper.GetPortForwarded
wg.Add(1)
// wait for restartOpenvpn // wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg) go openvpnLooper.Run(ctx, wg)
@@ -205,22 +206,27 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
restartUnbound := unboundLooper.Restart restartUnbound := unboundLooper.Restart
// wait for restartUnbound wg.Add(1)
// wait for restartUnbound or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg) go unboundLooper.Run(ctx, wg)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
restartPublicIP := publicIPLooper.Restart restartPublicIP := publicIPLooper.Restart
setPublicIPPeriod := publicIPLooper.SetPeriod setPublicIPPeriod := publicIPLooper.SetPeriod
go publicIPLooper.Run(ctx) wg.Add(1)
go publicIPLooper.RunRestartTicker(ctx) go publicIPLooper.Run(ctx, wg)
wg.Add(1)
go publicIPLooper.RunRestartTicker(ctx, wg)
setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface) tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface)
restartTinyproxy := tinyproxyLooper.Restart restartTinyproxy := tinyproxyLooper.Restart
wg.Add(1)
go tinyproxyLooper.Run(ctx, wg) go tinyproxyLooper.Run(ctx, wg)
shadowsocksLooper := shadowsocks.NewLooper(firewallConf, allSettings.ShadowSocks, logger, defaultInterface) shadowsocksLooper := shadowsocks.NewLooper(firewallConf, allSettings.ShadowSocks, logger, defaultInterface)
restartShadowsocks := shadowsocksLooper.Restart restartShadowsocks := shadowsocksLooper.Restart
wg.Add(1)
go shadowsocksLooper.Run(ctx, wg) go shadowsocksLooper.Run(ctx, wg)
if allSettings.TinyProxy.Enabled { if allSettings.TinyProxy.Enabled {
@@ -241,19 +247,26 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
} }
logger.Info(message) logger.Info(message)
} }
wg.Add(1)
go func() { go func() {
defer wg.Done()
tickerWg := &sync.WaitGroup{}
// for linters only
var restartTickerContext context.Context var restartTickerContext context.Context
var restartTickerCancel context.CancelFunc = func() {} var restartTickerCancel context.CancelFunc = func() {}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
restartTickerCancel() restartTickerCancel() // for linters only
tickerWg.Wait()
return return
case <-connectedCh: // blocks until openvpn is connected case <-connectedCh: // blocks until openvpn is connected
restartTickerCancel() restartTickerCancel() // stop previous restart tickers
tickerWg.Wait()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx) restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
go unboundLooper.RunRestartTicker(restartTickerContext) tickerWg.Add(2)
go updaterLooper.RunRestartTicker(ctx) go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation) onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation)
} }
} }
@@ -261,6 +274,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart, httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart,
getOpenvpnSettings, getPortForwarded) getOpenvpnSettings, getPortForwarded)
wg.Add(1)
go httpServer.Run(ctx, wg) go httpServer.Run(ctx, wg)
// Start openvpn for the first time // Start openvpn for the first time

View File

@@ -14,7 +14,7 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() Restart()
Start() Start()
Stop() Stop()
@@ -139,7 +139,6 @@ func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel conte
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done() defer wg.Done()
const fallback = false const fallback = false
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
@@ -282,7 +281,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers) l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers)
} }
func (l *looper) RunRestartTicker(ctx context.Context) { func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(time.Hour) ticker := time.NewTicker(time.Hour)
settings := l.GetSettings() settings := l.GetSettings()
if settings.UpdatePeriod > 0 { if settings.UpdatePeriod > 0 {

View File

@@ -98,7 +98,6 @@ func (l *looper) SetAllServers(allServers models.AllServers) {
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done() defer wg.Done()
select { select {
case <-l.restart: case <-l.restart:

View File

@@ -12,8 +12,8 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() Restart()
Stop() Stop()
GetPeriod() (period time.Duration) GetPeriod() (period time.Duration)
@@ -74,7 +74,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done() <-ctx.Done()
} }
func (l *looper) Run(ctx context.Context) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
select { select {
case <-l.restart: case <-l.restart:
case <-ctx.Done(): case <-ctx.Done():
@@ -124,7 +125,8 @@ func (l *looper) Run(ctx context.Context) {
} }
} }
func (l *looper) RunRestartTicker(ctx context.Context) { func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(time.Hour) ticker := time.NewTicker(time.Hour)
period := l.GetPeriod() period := l.GetPeriod()
if period > 0 { if period > 0 {

View File

@@ -42,7 +42,6 @@ func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound,
} }
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
server := http.Server{Addr: s.address, Handler: s.makeHandler()} server := http.Server{Addr: s.address, Handler: s.makeHandler()}
go func() { go func() {
defer wg.Done() defer wg.Done()

View File

@@ -82,7 +82,6 @@ func (l *looper) setEnabled(enabled bool) {
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done() defer wg.Done()
waitForStart := true waitForStart := true
for waitForStart { for waitForStart {

View File

@@ -89,7 +89,6 @@ func (l *looper) Start() { l.start <- struct{}{} }
func (l *looper) Stop() { l.stop <- struct{}{} } func (l *looper) Stop() { l.stop <- struct{}{} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done() defer wg.Done()
waitForStart := true waitForStart := true
for waitForStart { for waitForStart {

View File

@@ -13,7 +13,7 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() Restart()
Stop() Stop()
GetPeriod() (period time.Duration) GetPeriod() (period time.Duration)
@@ -123,7 +123,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
} }
} }
func (l *looper) RunRestartTicker(ctx context.Context) { func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(time.Hour) ticker := time.NewTicker(time.Hour)
period := l.GetPeriod() period := l.GetPeriod()
if period > 0 { if period > 0 {