Fix exiting without fatalOnError function

This commit is contained in:
Quentin McGaw
2020-08-27 22:59:58 +00:00
parent aa9693a84d
commit 93aaf1ab02
2 changed files with 51 additions and 31 deletions

View File

@@ -36,7 +36,7 @@ func main() {
os.Exit(_main(ctx, os.Args)) os.Exit(_main(ctx, os.Args))
} }
func _main(background context.Context, args []string) int { func _main(background context.Context, args []string) int { //nolint:gocognit,gocyclo
if len(args) > 1 { // cli operation if len(args) > 1 { // cli operation
var err error var err error
switch args[1] { switch args[1] {
@@ -59,8 +59,6 @@ func _main(background context.Context, args []string) int {
defer cancel() defer cancel()
logger := createLogger() logger := createLogger()
fatalOnError := makeFatalOnError(logger, cancel)
client := network.NewClient(15 * time.Second) client := network.NewClient(15 * time.Second)
// Create configurators // Create configurators
fileManager := files.NewFileManager() fileManager := files.NewFileManager()
@@ -86,7 +84,10 @@ func _main(background context.Context, args []string) int {
}) })
allSettings, err := settings.GetAllSettings(paramsReader) allSettings, err := settings.GetAllSettings(paramsReader)
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
logger.Info(allSettings.String()) logger.Info(allSettings.String())
// TODO run this in a loop or in openvpn to reload from file without restarting // TODO run this in a loop or in openvpn to reload from file without restarting
@@ -101,11 +102,20 @@ func _main(background context.Context, args []string) int {
uid, gid := allSettings.System.UID, allSettings.System.GID uid, gid := allSettings.System.UID, allSettings.System.GID
err = alpineConf.CreateUser("nonrootuser", uid) err = alpineConf.CreateUser("nonrootuser", uid)
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
err = fileManager.SetOwnership("/etc/unbound", uid, gid) err = fileManager.SetOwnership("/etc/unbound", uid, gid)
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
err = fileManager.SetOwnership("/etc/tinyproxy", uid, gid) err = fileManager.SetOwnership("/etc/tinyproxy", uid, gid)
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
if allSettings.Firewall.Debug { if allSettings.Firewall.Debug {
firewallConf.SetDebug() firewallConf.SetDebug()
@@ -114,12 +124,14 @@ func _main(background context.Context, args []string) int {
defaultInterface, defaultGateway, err := routingConf.DefaultRoute() defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
if err != nil { if err != nil {
fatalOnError(err) logger.Error(err)
return 1
} }
localSubnet, err := routingConf.LocalSubnet() localSubnet, err := routingConf.LocalSubnet()
if err != nil { if err != nil {
fatalOnError(err) logger.Error(err)
return 1
} }
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet) firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet)
@@ -127,7 +139,10 @@ func _main(background context.Context, args []string) int {
if err := ovpnConf.CheckTUN(); err != nil { if err := ovpnConf.CheckTUN(); err != nil {
logger.Warn(err) logger.Warn(err)
err = ovpnConf.CreateTUN() err = ovpnConf.CreateTUN()
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
} }
connectedCh := make(chan struct{}) connectedCh := make(chan struct{})
@@ -135,25 +150,35 @@ func _main(background context.Context, args []string) int {
connectedCh <- struct{}{} connectedCh <- struct{}{}
} }
defer close(connectedCh) defer close(connectedCh)
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
if allSettings.Firewall.Enabled { if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default err := firewallConf.SetEnabled(ctx, true) // disabled by default
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
} }
err = firewallConf.SetAllowedSubnets(ctx, allSettings.Firewall.AllowedSubnets) err = firewallConf.SetAllowedSubnets(ctx, allSettings.Firewall.AllowedSubnets)
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
for _, vpnPort := range allSettings.Firewall.VPNInputPorts { for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN)) err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
fatalOnError(err) if err != nil {
logger.Error(err)
return 1
}
} }
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError) ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel)
restartOpenvpn := openvpnLooper.Restart restartOpenvpn := openvpnLooper.Restart
portForward := openvpnLooper.PortForward portForward := openvpnLooper.PortForward
getOpenvpnSettings := openvpnLooper.GetSettings getOpenvpnSettings := openvpnLooper.GetSettings
@@ -258,15 +283,6 @@ func _main(background context.Context, args []string) int {
return 0 return 0
} }
func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc) func(err error) {
return func(err error) {
if err != nil {
logger.Error(err)
cancel()
}
}
}
func createLogger() logging.Logger { func createLogger() logging.Logger {
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel, -1) logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel, -1)
if err != nil { if err != nil {

View File

@@ -45,7 +45,7 @@ type looper struct {
client network.Client client network.Client
fileManager files.FileManager fileManager files.FileManager
streamMerger command.StreamMerger streamMerger command.StreamMerger
fatalOnError func(err error) cancel context.CancelFunc
// Internal channels // Internal channels
restart chan struct{} restart chan struct{}
portForwardSignals chan struct{} portForwardSignals chan struct{}
@@ -55,7 +55,7 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
uid, gid int, allServers models.AllServers, uid, gid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, conf Configurator, fw firewall.Configurator,
logger logging.Logger, client network.Client, fileManager files.FileManager, logger logging.Logger, client network.Client, fileManager files.FileManager,
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper { streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
return &looper{ return &looper{
provider: provider, provider: provider,
settings: settings, settings: settings,
@@ -68,7 +68,7 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
client: client, client: client,
fileManager: fileManager, fileManager: fileManager,
streamMerger: streamMerger, streamMerger: streamMerger,
fatalOnError: fatalOnError, cancel: cancel,
restart: make(chan struct{}), restart: make(chan struct{}),
portForwardSignals: make(chan struct{}), portForwardSignals: make(chan struct{}),
} }
@@ -104,7 +104,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
providerConf := provider.New(l.provider, l.allServers) providerConf := provider.New(l.provider, l.allServers)
connections, err := providerConf.GetOpenVPNConnections(settings.Provider.ServerSelection) connections, err := providerConf.GetOpenVPNConnections(settings.Provider.ServerSelection)
if err != nil { if err != nil {
l.fatalOnError(err) l.logger.Error(err)
l.cancel()
return return
} }
lines := providerConf.BuildConf( lines := providerConf.BuildConf(
@@ -118,17 +119,20 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
settings.Provider.ExtraConfigOptions, settings.Provider.ExtraConfigOptions,
) )
if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines, files.Ownership(l.uid, l.gid), files.Permissions(0400)); err != nil { if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines, files.Ownership(l.uid, l.gid), files.Permissions(0400)); err != nil {
l.fatalOnError(err) l.logger.Error(err)
l.cancel()
return return
} }
if err := l.conf.WriteAuthFile(settings.User, settings.Password, l.uid, l.gid); err != nil { if err := l.conf.WriteAuthFile(settings.User, settings.Password, l.uid, l.gid); err != nil {
l.fatalOnError(err) l.logger.Error(err)
l.cancel()
return return
} }
if err := l.fw.SetVPNConnections(ctx, connections); err != nil { if err := l.fw.SetVPNConnections(ctx, connections); err != nil {
l.fatalOnError(err) l.logger.Error(err)
l.cancel()
return return
} }