Revisit waitgroup (#241)
* Fix Add to waitgroup out of goroutines calling wg.Done() * Pass waitgroup to other loop functions
This commit is contained in:
@@ -194,6 +194,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
portForward := openvpnLooper.PortForward
|
portForward := openvpnLooper.PortForward
|
||||||
getOpenvpnSettings := openvpnLooper.GetSettings
|
getOpenvpnSettings := openvpnLooper.GetSettings
|
||||||
getPortForwarded := openvpnLooper.GetPortForwarded
|
getPortForwarded := openvpnLooper.GetPortForwarded
|
||||||
|
wg.Add(1)
|
||||||
// wait for restartOpenvpn
|
// wait for restartOpenvpn
|
||||||
go openvpnLooper.Run(ctx, wg)
|
go openvpnLooper.Run(ctx, wg)
|
||||||
|
|
||||||
@@ -205,22 +206,27 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
|
|
||||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||||
restartUnbound := unboundLooper.Restart
|
restartUnbound := unboundLooper.Restart
|
||||||
// wait for restartUnbound
|
wg.Add(1)
|
||||||
|
// wait for restartUnbound or its ticker launched with RunRestartTicker
|
||||||
go unboundLooper.Run(ctx, wg)
|
go unboundLooper.Run(ctx, wg)
|
||||||
|
|
||||||
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
|
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
|
||||||
restartPublicIP := publicIPLooper.Restart
|
restartPublicIP := publicIPLooper.Restart
|
||||||
setPublicIPPeriod := publicIPLooper.SetPeriod
|
setPublicIPPeriod := publicIPLooper.SetPeriod
|
||||||
go publicIPLooper.Run(ctx)
|
wg.Add(1)
|
||||||
go publicIPLooper.RunRestartTicker(ctx)
|
go publicIPLooper.Run(ctx, wg)
|
||||||
|
wg.Add(1)
|
||||||
|
go publicIPLooper.RunRestartTicker(ctx, wg)
|
||||||
setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
|
setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
|
||||||
|
|
||||||
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface)
|
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface)
|
||||||
restartTinyproxy := tinyproxyLooper.Restart
|
restartTinyproxy := tinyproxyLooper.Restart
|
||||||
|
wg.Add(1)
|
||||||
go tinyproxyLooper.Run(ctx, wg)
|
go tinyproxyLooper.Run(ctx, wg)
|
||||||
|
|
||||||
shadowsocksLooper := shadowsocks.NewLooper(firewallConf, allSettings.ShadowSocks, logger, defaultInterface)
|
shadowsocksLooper := shadowsocks.NewLooper(firewallConf, allSettings.ShadowSocks, logger, defaultInterface)
|
||||||
restartShadowsocks := shadowsocksLooper.Restart
|
restartShadowsocks := shadowsocksLooper.Restart
|
||||||
|
wg.Add(1)
|
||||||
go shadowsocksLooper.Run(ctx, wg)
|
go shadowsocksLooper.Run(ctx, wg)
|
||||||
|
|
||||||
if allSettings.TinyProxy.Enabled {
|
if allSettings.TinyProxy.Enabled {
|
||||||
@@ -241,19 +247,26 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
}
|
}
|
||||||
logger.Info(message)
|
logger.Info(message)
|
||||||
}
|
}
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
tickerWg := &sync.WaitGroup{}
|
||||||
|
// for linters only
|
||||||
var restartTickerContext context.Context
|
var restartTickerContext context.Context
|
||||||
var restartTickerCancel context.CancelFunc = func() {}
|
var restartTickerCancel context.CancelFunc = func() {}
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
restartTickerCancel()
|
restartTickerCancel() // for linters only
|
||||||
|
tickerWg.Wait()
|
||||||
return
|
return
|
||||||
case <-connectedCh: // blocks until openvpn is connected
|
case <-connectedCh: // blocks until openvpn is connected
|
||||||
restartTickerCancel()
|
restartTickerCancel() // stop previous restart tickers
|
||||||
|
tickerWg.Wait()
|
||||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
go unboundLooper.RunRestartTicker(restartTickerContext)
|
tickerWg.Add(2)
|
||||||
go updaterLooper.RunRestartTicker(ctx)
|
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||||
|
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||||
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation)
|
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP, versionInformation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,6 +274,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
|
|
||||||
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart,
|
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart,
|
||||||
getOpenvpnSettings, getPortForwarded)
|
getOpenvpnSettings, getPortForwarded)
|
||||||
|
wg.Add(1)
|
||||||
go httpServer.Run(ctx, wg)
|
go httpServer.Run(ctx, wg)
|
||||||
|
|
||||||
// Start openvpn for the first time
|
// Start openvpn for the first time
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
RunRestartTicker(ctx context.Context)
|
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||||
Restart()
|
Restart()
|
||||||
Start()
|
Start()
|
||||||
Stop()
|
Stop()
|
||||||
@@ -139,7 +139,6 @@ func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
const fallback = false
|
const fallback = false
|
||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(fallback)
|
||||||
@@ -282,7 +281,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
|||||||
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers)
|
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context) {
|
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
ticker := time.NewTicker(time.Hour)
|
ticker := time.NewTicker(time.Hour)
|
||||||
settings := l.GetSettings()
|
settings := l.GetSettings()
|
||||||
if settings.UpdatePeriod > 0 {
|
if settings.UpdatePeriod > 0 {
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ func (l *looper) SetAllServers(allServers models.AllServers) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
select {
|
select {
|
||||||
case <-l.restart:
|
case <-l.restart:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
RunRestartTicker(ctx context.Context)
|
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||||
Restart()
|
Restart()
|
||||||
Stop()
|
Stop()
|
||||||
GetPeriod() (period time.Duration)
|
GetPeriod() (period time.Duration)
|
||||||
@@ -74,7 +74,8 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
select {
|
select {
|
||||||
case <-l.restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -124,7 +125,8 @@ func (l *looper) Run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context) {
|
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
ticker := time.NewTicker(time.Hour)
|
ticker := time.NewTicker(time.Hour)
|
||||||
period := l.GetPeriod()
|
period := l.GetPeriod()
|
||||||
if period > 0 {
|
if period > 0 {
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
|
||||||
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ func (l *looper) setEnabled(enabled bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
waitForStart := true
|
waitForStart := true
|
||||||
for waitForStart {
|
for waitForStart {
|
||||||
|
|||||||
@@ -89,7 +89,6 @@ func (l *looper) Start() { l.start <- struct{}{} }
|
|||||||
func (l *looper) Stop() { l.stop <- struct{}{} }
|
func (l *looper) Stop() { l.stop <- struct{}{} }
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
waitForStart := true
|
waitForStart := true
|
||||||
for waitForStart {
|
for waitForStart {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
RunRestartTicker(ctx context.Context)
|
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
|
||||||
Restart()
|
Restart()
|
||||||
Stop()
|
Stop()
|
||||||
GetPeriod() (period time.Duration)
|
GetPeriod() (period time.Duration)
|
||||||
@@ -123,7 +123,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context) {
|
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
ticker := time.NewTicker(time.Hour)
|
ticker := time.NewTicker(time.Hour)
|
||||||
period := l.GetPeriod()
|
period := l.GetPeriod()
|
||||||
if period > 0 {
|
if period > 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user