From 616ba0c538a64f376cc2758e031688533da8a507 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 15 Jul 2020 01:34:46 +0000 Subject: [PATCH] Replace explicit channels with functions --- cmd/gluetun/main.go | 44 ++++++++++++++++-------------------- internal/dns/loop.go | 21 ++++++++++------- internal/openvpn/loop.go | 42 +++++++++++++++++++++------------- internal/publicip/loop.go | 19 ++++++++++------ internal/server/server.go | 10 ++++---- internal/shadowsocks/loop.go | 13 +++++++---- internal/tinyproxy/loop.go | 13 +++++++---- 7 files changed, 94 insertions(+), 68 deletions(-) diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index b81b9fc7..455715db 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -117,14 +117,6 @@ func _main(background context.Context, args []string) int { defer close(connectedCh) go collectStreamLines(ctx, streamMerger, logger, signalConnected) - // TODO replace these with methods on loopers and pass loopers around - restartOpenvpn := make(chan struct{}) - portForward := make(chan struct{}) - restartUnbound := make(chan struct{}) - restartPublicIP := make(chan struct{}) - restartTinyproxy := make(chan struct{}) - restartShadowsocks := make(chan struct{}) - if allSettings.Firewall.Enabled { err := firewallConf.SetEnabled(ctx, true) // disabled by default fatalOnError(err) @@ -135,28 +127,34 @@ func _main(background context.Context, args []string) int { openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError) + restartOpenvpn := openvpnLooper.Restart + portForward := openvpnLooper.PortForward // wait for restartOpenvpn - go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg) + go openvpnLooper.Run(ctx, wg) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) + restartUnbound := unboundLooper.Restart // wait for restartUnbound - go unboundLooper.Run(ctx, restartUnbound, wg) + go unboundLooper.Run(ctx, wg) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, uid, gid) - go publicIPLooper.Run(ctx, restartPublicIP) - go publicIPLooper.RunRestartTicker(ctx, restartPublicIP) + restartPublicIP := publicIPLooper.Restart + go publicIPLooper.Run(ctx) + go publicIPLooper.RunRestartTicker(ctx) tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid) - go tinyproxyLooper.Run(ctx, restartTinyproxy, wg) + restartTinyproxy := tinyproxyLooper.Restart + go tinyproxyLooper.Run(ctx, wg) shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid) - go shadowsocksLooper.Run(ctx, restartShadowsocks, wg) + restartShadowsocks := shadowsocksLooper.Restart + go shadowsocksLooper.Run(ctx, wg) if allSettings.TinyProxy.Enabled { - restartTinyproxy <- struct{}{} + restartTinyproxy() } if allSettings.ShadowSocks.Enabled { - restartShadowsocks <- struct{}{} + restartShadowsocks() } go func() { @@ -170,7 +168,7 @@ func _main(background context.Context, args []string) int { case <-connectedCh: // blocks until openvpn is connected restartTickerCancel() restartTickerContext, restartTickerCancel = context.WithCancel(ctx) - go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound) + go unboundLooper.RunRestartTicker(restartTickerContext) onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP) } } @@ -180,7 +178,7 @@ func _main(background context.Context, args []string) int { go httpServer.Run(ctx, wg) // Start openvpn for the first time - restartOpenvpn <- struct{}{} + restartOpenvpn() signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, @@ -291,14 +289,12 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, } func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing, - portForward, restartUnbound, restartPublicIP chan<- struct{}, + portForward, restartUnbound, restartPublicIP func(), ) { - restartUnbound <- struct{}{} - restartPublicIP <- struct{}{} + restartUnbound() + restartPublicIP() if allSettings.OpenVPN.Provider.PortForwarding.Enabled { - time.AfterFunc(5*time.Second, func() { - portForward <- struct{}{} - }) + time.AfterFunc(5*time.Second, portForward) } defaultInterface, _, err := routingConf.DefaultRoute() if err != nil { diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 6fffad7e..1eb7a66a 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -13,8 +13,9 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) - RunRestartTicker(ctx context.Context, restart chan<- struct{}) + Run(ctx context.Context, wg *sync.WaitGroup) + RunRestartTicker(ctx context.Context) + Restart() } type looper struct { @@ -24,6 +25,7 @@ type looper struct { streamMerger command.StreamMerger uid int gid int + restart chan struct{} } func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, @@ -35,9 +37,12 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, uid: uid, gid: gid, streamMerger: streamMerger, + restart: make(chan struct{}), } } +func (l *looper) Restart() { l.restart <- struct{}{} } + func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Warn(err) l.logger.Info("attempting restart in 10 seconds") @@ -46,12 +51,12 @@ func (l *looper) logAndWait(ctx context.Context, err error) { <-ctx.Done() } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { +func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() l.fallbackToUnencryptedDNS() select { - case <-restart: + case <-l.restart: case <-ctx.Done(): return } @@ -65,7 +70,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait if !l.settings.Enabled { // wait for another restart signal to recheck if it is enabled select { - case <-restart: + case <-l.restart: case <-ctx.Done(): unboundCancel() return @@ -127,7 +132,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait <-waitError close(waitError) return - case <-restart: // triggered restart + case <-l.restart: // triggered restart l.logger.Info("restarting") // unboundCancel occurs next loop run when the setup is complete triggeredRestart = true @@ -172,7 +177,7 @@ func (l *looper) fallbackToUnencryptedDNS() { l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers) } -func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) { +func (l *looper) RunRestartTicker(ctx context.Context) { if l.settings.UpdatePeriod == 0 { return } @@ -183,7 +188,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) ticker.Stop() return case <-ticker.C: - restart <- struct{}{} + l.restart <- struct{}{} } } } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 40153e0b..080b606c 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -18,7 +18,9 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) + Run(ctx context.Context, wg *sync.WaitGroup) + Restart() + PortForward() } type looper struct { @@ -37,6 +39,9 @@ type looper struct { fileManager files.FileManager streamMerger command.StreamMerger fatalOnError func(err error) + // Internal channels + restart chan struct{} + portForwardSignals chan struct{} } func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, @@ -45,25 +50,30 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, logger logging.Logger, client network.Client, fileManager files.FileManager, streamMerger command.StreamMerger, fatalOnError func(err error)) Looper { return &looper{ - provider: provider, - settings: settings, - uid: uid, - gid: gid, - conf: conf, - fw: fw, - logger: logger.WithPrefix("openvpn: "), - client: client, - fileManager: fileManager, - streamMerger: streamMerger, - fatalOnError: fatalOnError, + provider: provider, + settings: settings, + uid: uid, + gid: gid, + conf: conf, + fw: fw, + logger: logger.WithPrefix("openvpn: "), + client: client, + fileManager: fileManager, + streamMerger: streamMerger, + fatalOnError: fatalOnError, + restart: make(chan struct{}), + portForwardSignals: make(chan struct{}), } } -func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) { +func (l *looper) Restart() { l.restart <- struct{}{} } +func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} } + +func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() select { - case <-restart: + case <-l.restart: case <-ctx.Done(): return } @@ -107,7 +117,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, select { case <-ctx.Done(): return - case <-portForward: + case <-l.portForwardSignals: l.portForward(ctx, providerConf, l.client) } } @@ -126,7 +136,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, <-waitError close(waitError) return - case <-restart: // triggered restart + case <-l.restart: // triggered restart l.logger.Info("restarting") openvpnCancel() <-waitError diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index be374288..6062e95a 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -11,8 +11,9 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}) - RunRestartTicker(ctx context.Context, restart chan<- struct{}) + Run(ctx context.Context) + RunRestartTicker(ctx context.Context) + Restart() } type looper struct { @@ -22,6 +23,7 @@ type looper struct { ipStatusFilepath models.Filepath uid int gid int + restart chan struct{} } func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, @@ -33,9 +35,12 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F ipStatusFilepath: ipStatusFilepath, uid: uid, gid: gid, + restart: make(chan struct{}), } } +func (l *looper) Restart() { l.restart <- struct{}{} } + func (l *looper) logAndWait(ctx context.Context, err error) { l.logger.Error(err) l.logger.Info("retrying in 5 seconds") @@ -44,9 +49,9 @@ func (l *looper) logAndWait(ctx context.Context, err error) { <-ctx.Done() } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { +func (l *looper) Run(ctx context.Context) { select { - case <-restart: + case <-l.restart: case <-ctx.Done(): return } @@ -69,7 +74,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { continue } select { - case <-restart: // triggered restart + case <-l.restart: // triggered restart case <-ctx.Done(): l.logger.Warn("context canceled: exiting loop") return @@ -77,7 +82,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { } } -func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) { +func (l *looper) RunRestartTicker(ctx context.Context) { ticker := time.NewTicker(time.Hour) for { select { @@ -85,7 +90,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) ticker.Stop() return case <-ticker.C: - restart <- struct{}{} + l.restart <- struct{}{} } } } diff --git a/internal/server/server.go b/internal/server/server.go index 397c9cc0..65b055c4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,11 +17,11 @@ type Server interface { type server struct { address string logger logging.Logger - restartOpenvpn chan<- struct{} - restartUnbound chan<- struct{} + restartOpenvpn func() + restartUnbound func() } -func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server { +func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound func()) Server { return &server{ address: address, logger: logger.WithPrefix("http server: "), @@ -58,9 +58,9 @@ func (s *server) makeHandler() http.HandlerFunc { case http.MethodGet: switch r.RequestURI { case "/openvpn/actions/restart": - s.restartOpenvpn <- struct{}{} + s.restartOpenvpn() case "/unbound/actions/restart": - s.restartUnbound <- struct{}{} + s.restartUnbound() default: routeDoesNotExist(s.logger, w, r) } diff --git a/internal/shadowsocks/loop.go b/internal/shadowsocks/loop.go index d5dabec6..ce637941 100644 --- a/internal/shadowsocks/loop.go +++ b/internal/shadowsocks/loop.go @@ -12,7 +12,8 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) + Run(ctx context.Context, wg *sync.WaitGroup) + Restart() } type looper struct { @@ -24,6 +25,7 @@ type looper struct { streamMerger command.StreamMerger uid int gid int + restart chan struct{} } func (l *looper) logAndWait(ctx context.Context, err error) { @@ -45,14 +47,17 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s streamMerger: streamMerger, uid: uid, gid: gid, + restart: make(chan struct{}), } } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { +func (l *looper) Restart() { l.restart <- struct{}{} } + +func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() select { - case <-restart: + case <-l.restart: case <-ctx.Done(): return } @@ -109,7 +114,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait <-waitError close(waitError) return - case <-restart: // triggered restart + case <-l.restart: // triggered restart l.logger.Info("restarting") shadowsocksCancel() <-waitError diff --git a/internal/tinyproxy/loop.go b/internal/tinyproxy/loop.go index f01a48fb..02f0c159 100644 --- a/internal/tinyproxy/loop.go +++ b/internal/tinyproxy/loop.go @@ -12,7 +12,8 @@ import ( ) type Looper interface { - Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) + Run(ctx context.Context, wg *sync.WaitGroup) + Restart() } type looper struct { @@ -23,6 +24,7 @@ type looper struct { streamMerger command.StreamMerger uid int gid int + restart chan struct{} } func (l *looper) logAndWait(ctx context.Context, err error) { @@ -43,14 +45,17 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s streamMerger: streamMerger, uid: uid, gid: gid, + restart: make(chan struct{}), } } -func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { +func (l *looper) Restart() { l.restart <- struct{}{} } + +func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() select { - case <-restart: + case <-l.restart: case <-ctx.Done(): return } @@ -102,7 +107,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait <-waitError close(waitError) return - case <-restart: // triggered restart + case <-l.restart: // triggered restart l.logger.Info("restarting") tinyproxyCancel() <-waitError