Fix public IP on restarts, refers to 359

This commit is contained in:
Quentin McGaw
2021-01-29 00:06:55 +00:00
parent bc83b75634
commit 5194361f3b
3 changed files with 35 additions and 32 deletions

View File

@@ -236,9 +236,7 @@ func _main(background context.Context, buildInfo models.BuildInformation,
} }
tunnelReadyCh := make(chan struct{}) tunnelReadyCh := make(chan struct{})
dnsReadyCh := make(chan struct{})
defer close(tunnelReadyCh) defer close(tunnelReadyCh)
defer close(dnsReadyCh)
if allSettings.Firewall.Enabled { if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default err := firewallConf.SetEnabled(ctx, true) // disabled by default
@@ -279,7 +277,7 @@ func _main(background context.Context, buildInfo models.BuildInformation,
logger, nonRootUsername, puid, pgid) logger, nonRootUsername, puid, pgid)
wg.Add(1) wg.Add(1)
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg, dnsReadyCh) go unboundLooper.Run(ctx, wg)
publicIPLooper := publicip.NewLooper( publicIPLooper := publicip.NewLooper(
httpClient, logger, allSettings.PublicIP, puid, pgid, os) httpClient, logger, allSettings.PublicIP, puid, pgid, os)
@@ -297,7 +295,7 @@ func _main(background context.Context, buildInfo models.BuildInformation,
go shadowsocksLooper.Run(ctx, wg) go shadowsocksLooper.Run(ctx, wg)
wg.Add(1) wg.Add(1)
go routeReadyEvents(ctx, wg, buildInfo, tunnelReadyCh, dnsReadyCh, go routeReadyEvents(ctx, wg, buildInfo, tunnelReadyCh,
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
) )
@@ -347,7 +345,7 @@ func printVersions(ctx context.Context, logger logging.Logger,
} }
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.BuildInformation, func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.BuildInformation,
tunnelReadyCh, dnsReadyCh <-chan struct{}, tunnelReadyCh <-chan struct{},
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
routing routing.Routing, logger logging.Logger, httpClient *http.Client, routing routing.Routing, logger logging.Logger, httpClient *http.Client,
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) { versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
@@ -356,6 +354,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
// for linters only // for linters only
var restartTickerContext context.Context var restartTickerContext context.Context
var restartTickerCancel context.CancelFunc = func() {} var restartTickerCancel context.CancelFunc = func() {}
first := true
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -363,22 +362,41 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
tickerWg.Wait() tickerWg.Wait()
return return
case <-tunnelReadyCh: // blocks until openvpn is connected case <-tunnelReadyCh: // blocks until openvpn is connected
if unboundLooper.GetSettings().Enabled {
_, _ = unboundLooper.SetStatus(constants.Running)
}
restartTickerCancel() // stop previous restart tickers
tickerWg.Wait()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
//nolint:gomnd
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
vpnDestination, err := routing.VPNDestinationIP() vpnDestination, err := routing.VPNDestinationIP()
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} else { } else {
logger.Info("VPN routing IP address: %s", vpnDestination) logger.Info("VPN routing IP address: %s", vpnDestination)
} }
if unboundLooper.GetSettings().Enabled {
_, _ = unboundLooper.SetStatus(constants.Running)
}
restartTickerCancel() // stop previous restart tickers
tickerWg.Wait()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
// Runs the Public IP getter job once
_, _ = publicIPLooper.SetStatus(constants.Running)
if !versionInformation {
break
}
if first {
first = false
message, err := versionpkg.GetMessage(ctx, buildInfo, httpClient)
if err != nil {
logger.Error(err)
} else {
logger.Info(message)
}
}
//nolint:gomnd
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
if portForwardingEnabled { if portForwardingEnabled {
// vpnGateway required only for PIA // vpnGateway required only for PIA
vpnGateway, err := routing.VPNLocalGatewayIP() vpnGateway, err := routing.VPNLocalGatewayIP()
@@ -388,18 +406,6 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, buildInfo models.
logger.Info("VPN gateway IP address: %s", vpnGateway) logger.Info("VPN gateway IP address: %s", vpnGateway)
startPortForward(vpnGateway) startPortForward(vpnGateway)
} }
case <-dnsReadyCh:
// Runs the Public IP getter job once
_, _ = publicIPLooper.SetStatus(constants.Running)
if !versionInformation {
break
}
message, err := versionpkg.GetMessage(ctx, buildInfo, httpClient)
if err != nil {
logger.Error(err)
break
}
logger.Info(message)
} }
} }
} }

View File

@@ -16,7 +16,7 @@ import (
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup, dnsReadyCh chan<- struct{}) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
GetStatus() (status models.LoopStatus) GetStatus() (status models.LoopStatus)
SetStatus(status models.LoopStatus) (outcome string, err error) SetStatus(status models.LoopStatus) (outcome string, err error)
@@ -83,7 +83,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, dnsReadyCh chan<- struct{}) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
const fallback = false const fallback = false
@@ -136,8 +136,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, dnsReadyCh chan<-
closeStreams = func() {} closeStreams = func() {}
} }
dnsReadyCh <- struct{}{}
stayHere := true stayHere := true
for stayHere { for stayHere {
select { select {

View File

@@ -138,7 +138,6 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
} }
return return
case <-l.start: case <-l.start:
l.logger.Info("starting")
getCancel() getCancel()
stayHere = false stayHere = false
case <-l.stop: case <-l.stop: