diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 90330b14..253cc04d 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -156,10 +156,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go } } - connectedCh, dnsReadyCh := make(chan struct{}), make(chan struct{}) - signalConnected := func() { connectedCh <- struct{}{} } + tunnelReadyCh, dnsReadyCh := make(chan struct{}), make(chan struct{}) + signalTunnelReady := func() { tunnelReadyCh <- struct{}{} } signalDNSReady := func() { dnsReadyCh <- struct{}{} } - defer close(connectedCh) + defer close(tunnelReadyCh) defer close(dnsReadyCh) if allSettings.Firewall.Enabled { @@ -186,14 +186,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go wg := &sync.WaitGroup{} - go collectStreamLines(ctx, streamMerger, logger, signalConnected) + go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel) - restartOpenvpn := openvpnLooper.Restart - portForward := openvpnLooper.PortForward - getOpenvpnSettings := openvpnLooper.GetSettings - getPortForwarded := openvpnLooper.GetPortForwarded wg.Add(1) // wait for restartOpenvpn go openvpnLooper.Run(ctx, wg) @@ -205,19 +201,16 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go go updaterLooper.Run(ctx, wg) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) - restartUnbound := unboundLooper.Restart wg.Add(1) - // wait for restartUnbound or its ticker launched with RunRestartTicker + // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker go unboundLooper.Run(ctx, wg, signalDNSReady) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) - restartPublicIP := publicIPLooper.Restart - setPublicIPPeriod := publicIPLooper.SetPeriod wg.Add(1) go publicIPLooper.Run(ctx, wg) wg.Add(1) go publicIPLooper.RunRestartTicker(ctx, wg) - setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker + publicIPLooper.SetPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface) restartTinyproxy := tinyproxyLooper.Restart @@ -236,52 +229,18 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go restartShadowsocks() } - versionInformation := func() { - if !allSettings.VersionInformation { - return - } - message, err := versionpkg.GetMessage(version, commit, httpClient) - if err != nil { - logger.Error(err) - return - } - logger.Info(message) - } wg.Add(1) - go func() { - defer wg.Done() - tickerWg := &sync.WaitGroup{} - // for linters only - var restartTickerContext context.Context - var restartTickerCancel context.CancelFunc = func() {} - for { - select { - case <-ctx.Done(): - restartTickerCancel() // for linters only - tickerWg.Wait() - return - case <-connectedCh: // blocks until openvpn is connected - restartTickerCancel() // stop previous restart tickers - tickerWg.Wait() - restartTickerContext, restartTickerCancel = context.WithCancel(ctx) - tickerWg.Add(2) - go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) - go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) - onConnected(allSettings, logger, routingConf, portForward, restartUnbound) - case <-dnsReadyCh: - restartPublicIP() // TODO do not restart if disabled - versionInformation() - } - } - }() + go routeReadyEvents(ctx, wg, tunnelReadyCh, dnsReadyCh, + unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, + allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, + ) - httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart, - getOpenvpnSettings, getPortForwarded) + httpServer := server.New("0.0.0.0:8000", logger, openvpnLooper, unboundLooper, updaterLooper) wg.Add(1) go httpServer.Run(ctx, wg) // Start openvpn for the first time - restartOpenvpn() + openvpnLooper.Restart() signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, @@ -352,7 +311,7 @@ func printVersions(ctx context.Context, logger logging.Logger, versionFunctions } } -func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger logging.Logger, signalConnected func()) { +func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger logging.Logger, signalTunnelReady func()) { // Blocking line merging paramsReader for all programs: openvpn, tinyproxy, unbound and shadowsocks logger.Info("Launching standard output merger") streamMerger.CollectLines(ctx, func(line string) { @@ -369,29 +328,61 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger.Error(line) } if strings.Contains(line, "Initialization Sequence Completed") { - signalConnected() + signalTunnelReady() } }, func(err error) { logger.Warn(err) }) } -func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing, - portForward, restartUnbound func(), -) { - restartUnbound() - if allSettings.OpenVPN.Provider.PortForwarding.Enabled { - time.AfterFunc(5*time.Second, portForward) - } - defaultInterface, _, err := routingConf.DefaultRoute() - if err != nil { - logger.Warn(err) - } else { - vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface) - if err != nil { - logger.Warn(err) - } else { - logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) +func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{}, + unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, + routing routing.Routing, logger logging.Logger, httpClient *http.Client, + versionInformation, portForwardingEnabled bool, startPortForward func()) { + defer wg.Done() + tickerWg := &sync.WaitGroup{} + // for linters only + var restartTickerContext context.Context + var restartTickerCancel context.CancelFunc = func() {} + for { + select { + case <-ctx.Done(): + restartTickerCancel() // for linters only + tickerWg.Wait() + return + case <-tunnelReadyCh: // blocks until openvpn is connected + unboundLooper.Restart() + restartTickerCancel() // stop previous restart tickers + tickerWg.Wait() + restartTickerContext, restartTickerCancel = context.WithCancel(ctx) + tickerWg.Add(2) + go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) + go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) + if portForwardingEnabled { + time.AfterFunc(5*time.Second, startPortForward) + } + defaultInterface, _, err := routing.DefaultRoute() + if err != nil { + logger.Warn(err) + } else { + vpnGatewayIP, err := routing.VPNGatewayIP(defaultInterface) + if err != nil { + logger.Warn(err) + } else { + logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) + } + } + case <-dnsReadyCh: + publicIPLooper.Restart() // TODO do not restart if disabled + if !versionInformation { + break + } + message, err := versionpkg.GetMessage(version, commit, httpClient) + if err != nil { + logger.Error(err) + break + } + logger.Info(message) } } } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 91500278..1f92cca0 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -154,6 +154,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { continue } + // Needs the stream line from main.go to know when the tunnel is up go func(ctx context.Context) { for { select { diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index 94e1074a..b02e633b 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -6,7 +6,7 @@ import ( ) func (s *server) handleGetPortForwarded(w http.ResponseWriter) { - port := s.getPortForwarded() + port := s.openvpnLooper.GetPortForwarded() data, err := json.Marshal(struct { Port uint16 `json:"port"` }{port}) @@ -22,7 +22,7 @@ func (s *server) handleGetPortForwarded(w http.ResponseWriter) { } func (s *server) handleGetOpenvpnSettings(w http.ResponseWriter) { - settings := s.getOpenvpnSettings() + settings := s.openvpnLooper.GetSettings() data, err := json.Marshal(settings) if err != nil { s.logger.Warn(err) diff --git a/internal/server/server.go b/internal/server/server.go index 957f4855..5abcdb99 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -8,7 +8,9 @@ import ( "sync" "time" - "github.com/qdm12/gluetun/internal/settings" + "github.com/qdm12/gluetun/internal/dns" + "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/golibs/logging" ) @@ -17,27 +19,22 @@ type Server interface { } type server struct { - address string - logger logging.Logger - restartOpenvpn func() - restartUnbound func() - restartUpdater func() - getOpenvpnSettings func() settings.OpenVPN - getPortForwarded func() uint16 - lookupIP func(host string) ([]net.IP, error) + address string + logger logging.Logger + openvpnLooper openvpn.Looper + unboundLooper dns.Looper + updaterLooper updater.Looper + lookupIP func(host string) ([]net.IP, error) } -func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound, restartUpdater func(), - getOpenvpnSettings func() settings.OpenVPN, getPortForwarded func() uint16) Server { +func New(address string, logger logging.Logger, openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server { return &server{ - address: address, - logger: logger.WithPrefix("http server: "), - restartOpenvpn: restartOpenvpn, - restartUnbound: restartUnbound, - restartUpdater: restartUpdater, - getOpenvpnSettings: getOpenvpnSettings, - getPortForwarded: getPortForwarded, - lookupIP: net.LookupIP, + address: address, + logger: logger.WithPrefix("http server: "), + openvpnLooper: openvpnLooper, + unboundLooper: unboundLooper, + updaterLooper: updaterLooper, + lookupIP: net.LookupIP, } } @@ -68,10 +65,10 @@ func (s *server) makeHandler() http.HandlerFunc { case http.MethodGet: switch r.RequestURI { case "/openvpn/actions/restart": - s.restartOpenvpn() + s.openvpnLooper.Restart() w.WriteHeader(http.StatusOK) case "/unbound/actions/restart": - s.restartUnbound() + s.unboundLooper.Restart() w.WriteHeader(http.StatusOK) case "/openvpn/portforwarded": s.handleGetPortForwarded(w) @@ -80,7 +77,7 @@ func (s *server) makeHandler() http.HandlerFunc { case "/health": s.handleHealth(w) case "/updater/restart": - s.restartUpdater() + s.updaterLooper.Restart() w.WriteHeader(http.StatusOK) default: routeDoesNotExist(s.logger, w, r)