diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 665620c3..b5a77e29 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -36,7 +36,7 @@ func main() { os.Exit(_main(ctx, os.Args)) } -func _main(background context.Context, args []string) int { +func _main(background context.Context, args []string) int { //nolint:gocognit,gocyclo if len(args) > 1 { // cli operation var err error switch args[1] { @@ -59,8 +59,6 @@ func _main(background context.Context, args []string) int { defer cancel() logger := createLogger() - fatalOnError := makeFatalOnError(logger, cancel) - client := network.NewClient(15 * time.Second) // Create configurators fileManager := files.NewFileManager() @@ -86,7 +84,10 @@ func _main(background context.Context, args []string) int { }) allSettings, err := settings.GetAllSettings(paramsReader) - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } logger.Info(allSettings.String()) // TODO run this in a loop or in openvpn to reload from file without restarting @@ -101,11 +102,20 @@ func _main(background context.Context, args []string) int { uid, gid := allSettings.System.UID, allSettings.System.GID err = alpineConf.CreateUser("nonrootuser", uid) - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } err = fileManager.SetOwnership("/etc/unbound", uid, gid) - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } err = fileManager.SetOwnership("/etc/tinyproxy", uid, gid) - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } if allSettings.Firewall.Debug { firewallConf.SetDebug() @@ -114,12 +124,14 @@ func _main(background context.Context, args []string) int { defaultInterface, defaultGateway, err := routingConf.DefaultRoute() if err != nil { - fatalOnError(err) + logger.Error(err) + return 1 } localSubnet, err := routingConf.LocalSubnet() if err != nil { - fatalOnError(err) + logger.Error(err) + return 1 } firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet) @@ -127,7 +139,10 @@ func _main(background context.Context, args []string) int { if err := ovpnConf.CheckTUN(); err != nil { logger.Warn(err) err = ovpnConf.CreateTUN() - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } } connectedCh := make(chan struct{}) @@ -135,25 +150,35 @@ func _main(background context.Context, args []string) int { connectedCh <- struct{}{} } defer close(connectedCh) - go collectStreamLines(ctx, streamMerger, logger, signalConnected) if allSettings.Firewall.Enabled { err := firewallConf.SetEnabled(ctx, true) // disabled by default - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } } err = firewallConf.SetAllowedSubnets(ctx, allSettings.Firewall.AllowedSubnets) - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } for _, vpnPort := range allSettings.Firewall.VPNInputPorts { err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN)) - fatalOnError(err) + if err != nil { + logger.Error(err) + return 1 + } } wg := &sync.WaitGroup{} + go collectStreamLines(ctx, streamMerger, logger, signalConnected) + openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, - ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError) + ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel) restartOpenvpn := openvpnLooper.Restart portForward := openvpnLooper.PortForward getOpenvpnSettings := openvpnLooper.GetSettings @@ -258,15 +283,6 @@ func _main(background context.Context, args []string) int { return 0 } -func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc) func(err error) { - return func(err error) { - if err != nil { - logger.Error(err) - cancel() - } - } -} - func createLogger() logging.Logger { logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel, -1) if err != nil { diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index c10eb97d..3a7e7675 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -45,7 +45,7 @@ type looper struct { client network.Client fileManager files.FileManager streamMerger command.StreamMerger - fatalOnError func(err error) + cancel context.CancelFunc // Internal channels restart chan struct{} portForwardSignals chan struct{} @@ -55,7 +55,7 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, uid, gid int, allServers models.AllServers, conf Configurator, fw firewall.Configurator, logger logging.Logger, client network.Client, fileManager files.FileManager, - streamMerger command.StreamMerger, fatalOnError func(err error)) Looper { + streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { return &looper{ provider: provider, settings: settings, @@ -68,7 +68,7 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, client: client, fileManager: fileManager, streamMerger: streamMerger, - fatalOnError: fatalOnError, + cancel: cancel, restart: make(chan struct{}), portForwardSignals: make(chan struct{}), } @@ -104,7 +104,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { providerConf := provider.New(l.provider, l.allServers) connections, err := providerConf.GetOpenVPNConnections(settings.Provider.ServerSelection) if err != nil { - l.fatalOnError(err) + l.logger.Error(err) + l.cancel() return } lines := providerConf.BuildConf( @@ -118,17 +119,20 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { settings.Provider.ExtraConfigOptions, ) if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines, files.Ownership(l.uid, l.gid), files.Permissions(0400)); err != nil { - l.fatalOnError(err) + l.logger.Error(err) + l.cancel() return } if err := l.conf.WriteAuthFile(settings.User, settings.Password, l.uid, l.gid); err != nil { - l.fatalOnError(err) + l.logger.Error(err) + l.cancel() return } if err := l.fw.SetVPNConnections(ctx, connections); err != nil { - l.fatalOnError(err) + l.logger.Error(err) + l.cancel() return }