Maint: simplify settings code in internal/vpn

This commit is contained in:
Quentin McGaw (desktop)
2021-08-19 14:57:11 +00:00
parent 9218c7ef19
commit 5c2286f4e8
8 changed files with 31 additions and 48 deletions

View File

@@ -356,7 +356,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
tickersGroupHandler.Add(pubIPTickerHandler) tickersGroupHandler.Add(pubIPTickerHandler)
vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "}) vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "})
vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.VPN.Provider, vpnLooper := vpn.NewLoop(allSettings.VPN,
allServers, ovpnConf, firewallConf, routingConf, portForwardLooper, allServers, ovpnConf, firewallConf, routingConf, portForwardLooper,
cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient, cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
buildInfo, allSettings.VersionInformation) buildInfo, allSettings.VersionInformation)

View File

@@ -96,7 +96,7 @@ func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) {
} }
func (h *openvpnHandler) getSettings(w http.ResponseWriter) { func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
vpnSettings, _ := h.looper.GetSettings() vpnSettings := h.looper.GetSettings()
settings := vpnSettings.OpenVPN settings := vpnSettings.OpenVPN
settings.User = "redacted" settings.User = "redacted"
settings.Password = "redacted" settings.Password = "redacted"

View File

@@ -66,7 +66,6 @@ const (
) )
func NewLoop(vpnSettings configuration.VPN, func NewLoop(vpnSettings configuration.VPN,
providerSettings configuration.Provider,
allServers models.AllServers, openvpnConf openvpn.Interface, allServers models.AllServers, openvpnConf openvpn.Interface,
fw firewallConfigurer, routing routing.VPNGetter, fw firewallConfigurer, routing routing.VPNGetter,
portForward portforward.StartStopper, starter command.Starter, portForward portforward.StartStopper, starter command.Starter,
@@ -79,7 +78,7 @@ func NewLoop(vpnSettings configuration.VPN,
stopped := make(chan struct{}) stopped := make(chan struct{})
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, vpnSettings, providerSettings, allServers) state := state.New(statusManager, vpnSettings, allServers)
return &Loop{ return &Loop{
statusManager: statusManager, statusManager: statusManager,

View File

@@ -26,18 +26,17 @@ var (
// It returns a serverName for port forwarding (PIA) and an error if it fails. // It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
openvpnConf openvpn.Interface, providerConf provider.Provider, openvpnConf openvpn.Interface, providerConf provider.Provider,
openVPNSettings configuration.OpenVPN, providerSettings configuration.Provider, settings configuration.VPN, starter command.Starter, logger logging.Logger) (
starter command.Starter, logger logging.Logger) (
runner vpnRunner, serverName string, err error) { runner vpnRunner, serverName string, err error) {
var connection models.Connection var connection models.Connection
var lines []string var lines []string
if openVPNSettings.Config == "" { if settings.OpenVPN.Config == "" {
connection, err = providerConf.GetConnection(providerSettings.ServerSelection) connection, err = providerConf.GetConnection(settings.Provider.ServerSelection)
if err == nil { if err == nil {
lines = providerConf.BuildConf(connection, openVPNSettings) lines = providerConf.BuildConf(connection, settings.OpenVPN)
} }
} else { } else {
lines, connection, err = custom.BuildConfig(openVPNSettings) lines, connection, err = custom.BuildConfig(settings.OpenVPN)
} }
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err) return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
@@ -47,8 +46,8 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
return nil, "", fmt.Errorf("%w: %s", errWriteConfig, err) return nil, "", fmt.Errorf("%w: %s", errWriteConfig, err)
} }
if openVPNSettings.User != "" { if settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(openVPNSettings.User, openVPNSettings.Password) err := openvpnConf.WriteAuthFile(settings.OpenVPN.User, settings.OpenVPN.Password)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errWriteAuth, err) return nil, "", fmt.Errorf("%w: %s", errWriteAuth, err)
} }
@@ -58,7 +57,7 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
return nil, "", fmt.Errorf("%w: %s", errFirewall, err) return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
} }
runner = openvpn.NewRunner(openVPNSettings, starter, logger) runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
return runner, connection.Hostname, nil return runner, connection.Hostname, nil
} }

View File

@@ -26,20 +26,18 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
} }
for ctx.Err() == nil { for ctx.Err() == nil {
VPNSettings, providerSettings, allServers := l.state.GetSettingsAndServers() settings, allServers := l.state.GetSettingsAndServers()
providerConf := provider.New(providerSettings.Name, allServers, time.Now) providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw, vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, l.openvpnConf, providerConf, settings, l.starter, l.logger)
VPNSettings.OpenVPN, providerSettings,
l.starter, l.logger)
if err != nil { if err != nil {
l.crashed(ctx, err) l.crashed(ctx, err)
continue continue
} }
tunnelUpData := tunnelUpData{ tunnelUpData := tunnelUpData{
portForwarding: providerSettings.PortForwarding.Enabled, portForwarding: settings.Provider.PortForwarding.Enabled,
serverName: serverName, serverName: serverName,
portForwarder: providerConf, portForwarder: providerConf,
} }
@@ -67,7 +65,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-ctx.Done(): case <-ctx.Done():
const pfTimeout = 100 * time.Millisecond const pfTimeout = 100 * time.Millisecond
l.stopPortForwarding(context.Background(), l.stopPortForwarding(context.Background(),
providerSettings.PortForwarding.Enabled, pfTimeout) settings.Provider.PortForwarding.Enabled, pfTimeout)
openvpnCancel() openvpnCancel()
<-waitError <-waitError
close(waitError) close(waitError)
@@ -75,7 +73,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-l.stop: case <-l.stop:
l.userTrigger = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
l.stopPortForwarding(ctx, providerSettings.PortForwarding.Enabled, 0) l.stopPortForwarding(ctx, settings.Provider.PortForwarding.Enabled, 0)
openvpnCancel() openvpnCancel()
<-waitError <-waitError
// do not close waitError or the waitError // do not close waitError or the waitError
@@ -90,7 +88,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.statusManager.Lock() // prevent SetStatus from running in parallel l.statusManager.Lock() // prevent SetStatus from running in parallel
l.stopPortForwarding(ctx, providerSettings.PortForwarding.Enabled, 0) l.stopPortForwarding(ctx, settings.Provider.PortForwarding.Enabled, 0)
openvpnCancel() openvpnCancel()
l.statusManager.SetStatus(constants.Crashed) l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)

View File

@@ -9,13 +9,12 @@ import (
type SettingsGetSetter = state.SettingsGetSetter type SettingsGetSetter = state.SettingsGetSetter
func (l *Loop) GetSettings() ( func (l *Loop) GetSettings() (settings configuration.VPN) {
vpn configuration.VPN, provider configuration.Provider) {
return l.state.GetSettings() return l.state.GetSettings()
} }
func (l *Loop) SetSettings(ctx context.Context, func (l *Loop) SetSettings(ctx context.Context,
vpn configuration.VPN, provider configuration.Provider) ( vpn configuration.VPN) (
outcome string) { outcome string) {
return l.state.SetSettings(ctx, vpn, provider) return l.state.SetSettings(ctx, vpn)
} }

View File

@@ -13,17 +13,14 @@ var _ Manager = (*State)(nil)
type Manager interface { type Manager interface {
SettingsGetSetter SettingsGetSetter
ServersGetterSetter ServersGetterSetter
GetSettingsAndServers() (vpn configuration.VPN, GetSettingsAndServers() (vpn configuration.VPN, allServers models.AllServers)
provider configuration.Provider, allServers models.AllServers)
} }
func New(statusApplier loopstate.Applier, func New(statusApplier loopstate.Applier,
vpn configuration.VPN, provider configuration.Provider, vpn configuration.VPN, allServers models.AllServers) *State {
allServers models.AllServers) *State {
return &State{ return &State{
statusApplier: statusApplier, statusApplier: statusApplier,
vpn: vpn, vpn: vpn,
provider: provider,
allServers: allServers, allServers: allServers,
} }
} }
@@ -32,7 +29,6 @@ type State struct {
statusApplier loopstate.Applier statusApplier loopstate.Applier
vpn configuration.VPN vpn configuration.VPN
provider configuration.Provider
settingsMu sync.RWMutex settingsMu sync.RWMutex
allServers models.AllServers allServers models.AllServers
@@ -40,13 +36,12 @@ type State struct {
} }
func (s *State) GetSettingsAndServers() (vpn configuration.VPN, func (s *State) GetSettingsAndServers() (vpn configuration.VPN,
provider configuration.Provider, allServers models.AllServers) { allServers models.AllServers) {
s.settingsMu.RLock() s.settingsMu.RLock()
s.allServersMu.RLock() s.allServersMu.RLock()
vpn = s.vpn vpn = s.vpn
provider = s.provider
allServers = s.allServers allServers = s.allServers
s.settingsMu.RUnlock() s.settingsMu.RUnlock()
s.allServersMu.RUnlock() s.allServersMu.RUnlock()
return vpn, provider, allServers return vpn, allServers
} }

View File

@@ -9,33 +9,26 @@ import (
) )
type SettingsGetSetter interface { type SettingsGetSetter interface {
GetSettings() (vpn configuration.VPN, GetSettings() (vpn configuration.VPN)
provider configuration.Provider) SetSettings(ctx context.Context, vpn configuration.VPN) (outcome string)
SetSettings(ctx context.Context, vpn configuration.VPN,
provider configuration.Provider) (outcome string)
} }
func (s *State) GetSettings() (vpn configuration.VPN, func (s *State) GetSettings() (vpn configuration.VPN) {
provider configuration.Provider) {
s.settingsMu.RLock() s.settingsMu.RLock()
vpn = s.vpn vpn = s.vpn
provider = s.provider
s.settingsMu.RUnlock() s.settingsMu.RUnlock()
return vpn, provider return vpn
} }
func (s *State) SetSettings(ctx context.Context, func (s *State) SetSettings(ctx context.Context, vpn configuration.VPN) (
vpn configuration.VPN, provider configuration.Provider) (
outcome string) { outcome string) {
s.settingsMu.Lock() s.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(s.vpn, vpn) && settingsUnchanged := reflect.DeepEqual(s.vpn, vpn)
reflect.DeepEqual(s.provider, provider)
if settingsUnchanged { if settingsUnchanged {
s.settingsMu.Unlock() s.settingsMu.Unlock()
return "settings left unchanged" return "settings left unchanged"
} }
s.vpn = vpn s.vpn = vpn
s.provider = provider
s.settingsMu.Unlock() s.settingsMu.Unlock()
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped) _, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running) outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)