Maintenance: rework main function

This commit is contained in:
Quentin McGaw
2021-01-04 01:40:07 +00:00
parent c833e9a1a8
commit cfbf5624e1

View File

@@ -53,42 +53,79 @@ func main() {
Commit: commit, Commit: commit,
BuildDate: buildDate, BuildDate: buildDate,
} }
ctx := context.Background() ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
if err != nil {
fmt.Println(err)
nativeos.Exit(1)
}
args := nativeos.Args args := nativeos.Args
os := os.New() os := os.New()
osUser := user.New() osUser := user.New()
unix := unix.New() unix := unix.New()
cli := cli.New() cli := cli.New()
nativeos.Exit(_main(ctx, buildInfo, args, os, osUser, unix, cli))
errorCh := make(chan error)
go func() {
errorCh <- _main(ctx, buildInfo, args, logger, os, osUser, unix, cli)
}()
signalsCh := make(chan nativeos.Signal, 1)
signal.Notify(signalsCh,
syscall.SIGINT,
syscall.SIGTERM,
nativeos.Interrupt,
)
select {
case signal := <-signalsCh:
logger.Warn("Caught OS signal %s, shutting down", signal)
case err := <-errorCh:
logger.Error(err)
close(errorCh)
}
cancel()
const shutdownGracePeriod = 5 * time.Second
timer := time.NewTimer(shutdownGracePeriod)
select {
case <-errorCh:
if !timer.Stop() {
<-timer.C
}
logger.Info("Shutdown successful")
case <-timer.C:
logger.Warn("Shutdown timed out")
}
nativeos.Exit(1)
} }
//nolint:gocognit,gocyclo //nolint:gocognit,gocyclo
func _main(background context.Context, buildInfo models.BuildInformation, func _main(background context.Context, buildInfo models.BuildInformation,
args []string, os os.OS, osUser user.OSUser, unix unix.Unix, args []string, logger logging.Logger, os os.OS, osUser user.OSUser, unix unix.Unix,
cli cli.CLI) int { cli cli.CLI) error {
if len(args) > 1 { // cli operation if len(args) > 1 { // cli operation
var err error
switch args[1] { switch args[1] {
case "healthcheck": case "healthcheck":
err = cli.HealthCheck(background) return cli.HealthCheck(background)
case "clientkey": case "clientkey":
err = cli.ClientKey(args[2:], os.OpenFile) return cli.ClientKey(args[2:], os.OpenFile)
case "openvpnconfig": case "openvpnconfig":
err = cli.OpenvpnConfig(os) return cli.OpenvpnConfig(os)
case "update": case "update":
err = cli.Update(args[2:], os) return cli.Update(args[2:], os)
default: default:
err = fmt.Errorf("command %q is unknown", args[1]) return fmt.Errorf("command %q is unknown", args[1])
} }
if err != nil {
fmt.Println(err)
return 1
}
return 0
} }
ctx, cancel := context.WithCancel(background) ctx, cancel := context.WithCancel(background)
defer cancel() defer cancel()
logger := createLogger()
const clientTimeout = 15 * time.Second const clientTimeout = 15 * time.Second
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
@@ -114,26 +151,22 @@ func _main(background context.Context, buildInfo models.BuildInformation,
allSettings, err := settings.GetAllSettings(paramsReader) allSettings, err := settings.GetAllSettings(paramsReader)
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
logger.Info(allSettings.String()) logger.Info(allSettings.String())
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil { if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
logger.Error(err) return err
return 1
} }
if err := os.MkdirAll("/gluetun", 0644); err != nil { if err := os.MkdirAll("/gluetun", 0644); err != nil {
logger.Error(err) return err
return 1
} }
// TODO run this in a loop or in openvpn to reload from file without restarting // TODO run this in a loop or in openvpn to reload from file without restarting
storage := storage.New(logger, os, constants.ServersData) storage := storage.New(logger, os, constants.ServersData)
allServers, err := storage.SyncServers(constants.GetAllServers()) allServers, err := storage.SyncServers(constants.GetAllServers())
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
// Should never change // Should never change
@@ -142,16 +175,14 @@ func _main(background context.Context, buildInfo models.BuildInformation,
const defaultUsername = "nonrootuser" const defaultUsername = "nonrootuser"
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid) nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
if nonRootUsername != defaultUsername { if nonRootUsername != defaultUsername {
logger.Info("using existing username %s corresponding to user id %d", nonRootUsername, puid) logger.Info("using existing username %s corresponding to user id %d", nonRootUsername, puid)
} }
if err := os.Chown("/etc/unbound", puid, pgid); err != nil { if err := os.Chown("/etc/unbound", puid, pgid); err != nil {
logger.Error(err) return err
return 1
} }
if allSettings.Firewall.Debug { if allSettings.Firewall.Debug {
@@ -161,27 +192,23 @@ func _main(background context.Context, buildInfo models.BuildInformation,
defaultInterface, defaultGateway, err := routingConf.DefaultRoute() defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
localSubnet, err := routingConf.LocalSubnet() localSubnet, err := routingConf.LocalSubnet()
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
defaultIP, err := routingConf.DefaultIP() defaultIP, err := routingConf.DefaultIP()
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP) firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP)
if err := routingConf.Setup(); err != nil { if err := routingConf.Setup(); err != nil {
logger.Error(err) return err
return 1
} }
defer func() { defer func() {
routingConf.SetVerbose(false) routingConf.SetVerbose(false)
@@ -191,20 +218,17 @@ func _main(background context.Context, buildInfo models.BuildInformation,
}() }()
if err := firewallConf.SetOutboundSubnets(ctx, allSettings.Firewall.OutboundSubnets); err != nil { if err := firewallConf.SetOutboundSubnets(ctx, allSettings.Firewall.OutboundSubnets); err != nil {
logger.Error(err) return err
return 1
} }
if err := routingConf.SetOutboundRoutes(allSettings.Firewall.OutboundSubnets); err != nil { if err := routingConf.SetOutboundRoutes(allSettings.Firewall.OutboundSubnets); err != nil {
logger.Error(err) return err
return 1
} }
if err := ovpnConf.CheckTUN(); err != nil { if err := ovpnConf.CheckTUN(); err != nil {
logger.Warn(err) logger.Warn(err)
err = ovpnConf.CreateTUN() err = ovpnConf.CreateTUN()
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
} }
@@ -217,30 +241,27 @@ func _main(background context.Context, buildInfo models.BuildInformation,
if allSettings.Firewall.Enabled { if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default err := firewallConf.SetEnabled(ctx, true) // disabled by default
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
} }
for _, vpnPort := range allSettings.Firewall.VPNInputPorts { for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN)) err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
} }
for _, port := range allSettings.Firewall.InputPorts { for _, port := range allSettings.Firewall.InputPorts {
err = firewallConf.SetAllowedPort(ctx, port, defaultInterface) err = firewallConf.SetAllowedPort(ctx, port, defaultInterface)
if err != nil { if err != nil {
logger.Error(err) return err
return 1
} }
} // TODO move inside firewall? } // TODO move inside firewall?
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) // TODO waitgroup
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel) ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
@@ -296,55 +317,18 @@ func _main(background context.Context, buildInfo models.BuildInformation,
// until openvpn is launched // until openvpn is launched
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable _, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
signalsCh := make(chan nativeos.Signal, 1) <-ctx.Done()
signal.Notify(signalsCh,
syscall.SIGINT,
syscall.SIGTERM,
nativeos.Interrupt,
)
shutdownErrorsCount := 0
select {
case signal := <-signalsCh:
logger.Warn("Caught OS signal %s, shutting down", signal)
cancel()
case <-ctx.Done():
logger.Warn("context canceled, shutting down")
}
if allSettings.OpenVPN.Provider.PortForwarding.Enabled { if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
logger.Error(err) logger.Error(err)
shutdownErrorsCount++
} }
} }
const shutdownGracePeriod = 5 * time.Second
waiting, waited := context.WithTimeout(context.Background(), shutdownGracePeriod)
go func() {
defer waited()
wg.Wait()
}()
<-waiting.Done()
if waiting.Err() == context.DeadlineExceeded {
if shutdownErrorsCount > 0 {
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
}
logger.Warn("Shutdown timed out")
return 1
}
if shutdownErrorsCount > 0 {
logger.Warn("Shutdown had %d errors")
return 1
}
logger.Info("Shutdown successful")
return 0
}
func createLogger() logging.Logger { wg.Wait()
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
if err != nil { return nil
panic(err)
}
return logger
} }
func printVersions(ctx context.Context, logger logging.Logger, func printVersions(ctx context.Context, logger logging.Logger,