Maintenance: rework main function
This commit is contained in:
@@ -53,42 +53,79 @@ func main() {
|
||||
Commit: commit,
|
||||
BuildDate: buildDate,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
nativeos.Exit(1)
|
||||
}
|
||||
|
||||
args := nativeos.Args
|
||||
os := os.New()
|
||||
osUser := user.New()
|
||||
unix := unix.New()
|
||||
cli := cli.New()
|
||||
nativeos.Exit(_main(ctx, buildInfo, args, os, osUser, unix, cli))
|
||||
|
||||
errorCh := make(chan error)
|
||||
go func() {
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, os, osUser, unix, cli)
|
||||
}()
|
||||
|
||||
signalsCh := make(chan nativeos.Signal, 1)
|
||||
signal.Notify(signalsCh,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
nativeos.Interrupt,
|
||||
)
|
||||
|
||||
select {
|
||||
case signal := <-signalsCh:
|
||||
logger.Warn("Caught OS signal %s, shutting down", signal)
|
||||
case err := <-errorCh:
|
||||
logger.Error(err)
|
||||
close(errorCh)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
const shutdownGracePeriod = 5 * time.Second
|
||||
timer := time.NewTimer(shutdownGracePeriod)
|
||||
select {
|
||||
case <-errorCh:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
logger.Info("Shutdown successful")
|
||||
case <-timer.C:
|
||||
logger.Warn("Shutdown timed out")
|
||||
}
|
||||
|
||||
nativeos.Exit(1)
|
||||
}
|
||||
|
||||
//nolint:gocognit,gocyclo
|
||||
func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
args []string, os os.OS, osUser user.OSUser, unix unix.Unix,
|
||||
cli cli.CLI) int {
|
||||
args []string, logger logging.Logger, os os.OS, osUser user.OSUser, unix unix.Unix,
|
||||
cli cli.CLI) error {
|
||||
if len(args) > 1 { // cli operation
|
||||
var err error
|
||||
switch args[1] {
|
||||
case "healthcheck":
|
||||
err = cli.HealthCheck(background)
|
||||
return cli.HealthCheck(background)
|
||||
case "clientkey":
|
||||
err = cli.ClientKey(args[2:], os.OpenFile)
|
||||
return cli.ClientKey(args[2:], os.OpenFile)
|
||||
case "openvpnconfig":
|
||||
err = cli.OpenvpnConfig(os)
|
||||
return cli.OpenvpnConfig(os)
|
||||
case "update":
|
||||
err = cli.Update(args[2:], os)
|
||||
return cli.Update(args[2:], os)
|
||||
default:
|
||||
err = fmt.Errorf("command %q is unknown", args[1])
|
||||
return fmt.Errorf("command %q is unknown", args[1])
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
ctx, cancel := context.WithCancel(background)
|
||||
defer cancel()
|
||||
logger := createLogger()
|
||||
|
||||
const clientTimeout = 15 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
@@ -114,26 +151,22 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
allSettings, err := settings.GetAllSettings(paramsReader)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
logger.Info(allSettings.String())
|
||||
|
||||
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll("/gluetun", 0644); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||
storage := storage.New(logger, os, constants.ServersData)
|
||||
allServers, err := storage.SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
// Should never change
|
||||
@@ -142,16 +175,14 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
const defaultUsername = "nonrootuser"
|
||||
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
if nonRootUsername != defaultUsername {
|
||||
logger.Info("using existing username %s corresponding to user id %d", nonRootUsername, puid)
|
||||
}
|
||||
|
||||
if err := os.Chown("/etc/unbound", puid, pgid); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
if allSettings.Firewall.Debug {
|
||||
@@ -161,27 +192,23 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
localSubnet, err := routingConf.LocalSubnet()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
defaultIP, err := routingConf.DefaultIP()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP)
|
||||
|
||||
if err := routingConf.Setup(); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
routingConf.SetVerbose(false)
|
||||
@@ -191,20 +218,17 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
}()
|
||||
|
||||
if err := firewallConf.SetOutboundSubnets(ctx, allSettings.Firewall.OutboundSubnets); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
if err := routingConf.SetOutboundRoutes(allSettings.Firewall.OutboundSubnets); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ovpnConf.CheckTUN(); err != nil {
|
||||
logger.Warn(err)
|
||||
err = ovpnConf.CreateTUN()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,30 +241,27 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
if allSettings.Firewall.Enabled {
|
||||
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
|
||||
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, port := range allSettings.Firewall.InputPorts {
|
||||
err = firewallConf.SetAllowedPort(ctx, port, defaultInterface)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
return err
|
||||
}
|
||||
} // TODO move inside firewall?
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) // TODO waitgroup
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
|
||||
@@ -296,55 +317,18 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
// until openvpn is launched
|
||||
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
||||
|
||||
signalsCh := make(chan nativeos.Signal, 1)
|
||||
signal.Notify(signalsCh,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
nativeos.Interrupt,
|
||||
)
|
||||
shutdownErrorsCount := 0
|
||||
select {
|
||||
case signal := <-signalsCh:
|
||||
logger.Warn("Caught OS signal %s, shutting down", signal)
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
logger.Warn("context canceled, shutting down")
|
||||
}
|
||||
<-ctx.Done()
|
||||
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||
if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||
logger.Error(err)
|
||||
shutdownErrorsCount++
|
||||
}
|
||||
}
|
||||
const shutdownGracePeriod = 5 * time.Second
|
||||
waiting, waited := context.WithTimeout(context.Background(), shutdownGracePeriod)
|
||||
go func() {
|
||||
defer waited()
|
||||
wg.Wait()
|
||||
}()
|
||||
<-waiting.Done()
|
||||
if waiting.Err() == context.DeadlineExceeded {
|
||||
if shutdownErrorsCount > 0 {
|
||||
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
|
||||
}
|
||||
logger.Warn("Shutdown timed out")
|
||||
return 1
|
||||
}
|
||||
if shutdownErrorsCount > 0 {
|
||||
logger.Warn("Shutdown had %d errors")
|
||||
return 1
|
||||
}
|
||||
logger.Info("Shutdown successful")
|
||||
return 0
|
||||
}
|
||||
|
||||
func createLogger() logging.Logger {
|
||||
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return logger
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printVersions(ctx context.Context, logger logging.Logger,
|
||||
|
||||
Reference in New Issue
Block a user