Maintenance: rework main function
This commit is contained in:
@@ -53,42 +53,79 @@ func main() {
|
|||||||
Commit: commit,
|
Commit: commit,
|
||||||
BuildDate: buildDate,
|
BuildDate: buildDate,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
nativeos.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
args := nativeos.Args
|
args := nativeos.Args
|
||||||
os := os.New()
|
os := os.New()
|
||||||
osUser := user.New()
|
osUser := user.New()
|
||||||
unix := unix.New()
|
unix := unix.New()
|
||||||
cli := cli.New()
|
cli := cli.New()
|
||||||
nativeos.Exit(_main(ctx, buildInfo, args, os, osUser, unix, cli))
|
|
||||||
|
errorCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
errorCh <- _main(ctx, buildInfo, args, logger, os, osUser, unix, cli)
|
||||||
|
}()
|
||||||
|
|
||||||
|
signalsCh := make(chan nativeos.Signal, 1)
|
||||||
|
signal.Notify(signalsCh,
|
||||||
|
syscall.SIGINT,
|
||||||
|
syscall.SIGTERM,
|
||||||
|
nativeos.Interrupt,
|
||||||
|
)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case signal := <-signalsCh:
|
||||||
|
logger.Warn("Caught OS signal %s, shutting down", signal)
|
||||||
|
case err := <-errorCh:
|
||||||
|
logger.Error(err)
|
||||||
|
close(errorCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
const shutdownGracePeriod = 5 * time.Second
|
||||||
|
timer := time.NewTimer(shutdownGracePeriod)
|
||||||
|
select {
|
||||||
|
case <-errorCh:
|
||||||
|
if !timer.Stop() {
|
||||||
|
<-timer.C
|
||||||
|
}
|
||||||
|
logger.Info("Shutdown successful")
|
||||||
|
case <-timer.C:
|
||||||
|
logger.Warn("Shutdown timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
nativeos.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit,gocyclo
|
//nolint:gocognit,gocyclo
|
||||||
func _main(background context.Context, buildInfo models.BuildInformation,
|
func _main(background context.Context, buildInfo models.BuildInformation,
|
||||||
args []string, os os.OS, osUser user.OSUser, unix unix.Unix,
|
args []string, logger logging.Logger, os os.OS, osUser user.OSUser, unix unix.Unix,
|
||||||
cli cli.CLI) int {
|
cli cli.CLI) error {
|
||||||
if len(args) > 1 { // cli operation
|
if len(args) > 1 { // cli operation
|
||||||
var err error
|
|
||||||
switch args[1] {
|
switch args[1] {
|
||||||
case "healthcheck":
|
case "healthcheck":
|
||||||
err = cli.HealthCheck(background)
|
return cli.HealthCheck(background)
|
||||||
case "clientkey":
|
case "clientkey":
|
||||||
err = cli.ClientKey(args[2:], os.OpenFile)
|
return cli.ClientKey(args[2:], os.OpenFile)
|
||||||
case "openvpnconfig":
|
case "openvpnconfig":
|
||||||
err = cli.OpenvpnConfig(os)
|
return cli.OpenvpnConfig(os)
|
||||||
case "update":
|
case "update":
|
||||||
err = cli.Update(args[2:], os)
|
return cli.Update(args[2:], os)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("command %q is unknown", args[1])
|
return fmt.Errorf("command %q is unknown", args[1])
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithCancel(background)
|
ctx, cancel := context.WithCancel(background)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
logger := createLogger()
|
|
||||||
|
|
||||||
const clientTimeout = 15 * time.Second
|
const clientTimeout = 15 * time.Second
|
||||||
httpClient := &http.Client{Timeout: clientTimeout}
|
httpClient := &http.Client{Timeout: clientTimeout}
|
||||||
@@ -114,26 +151,22 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
|
|
||||||
allSettings, err := settings.GetAllSettings(paramsReader)
|
allSettings, err := settings.GetAllSettings(paramsReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
logger.Info(allSettings.String())
|
logger.Info(allSettings.String())
|
||||||
|
|
||||||
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
|
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
if err := os.MkdirAll("/gluetun", 0644); err != nil {
|
if err := os.MkdirAll("/gluetun", 0644); err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||||
storage := storage.New(logger, os, constants.ServersData)
|
storage := storage.New(logger, os, constants.ServersData)
|
||||||
allServers, err := storage.SyncServers(constants.GetAllServers())
|
allServers, err := storage.SyncServers(constants.GetAllServers())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should never change
|
// Should never change
|
||||||
@@ -142,16 +175,14 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
const defaultUsername = "nonrootuser"
|
const defaultUsername = "nonrootuser"
|
||||||
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
|
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
if nonRootUsername != defaultUsername {
|
if nonRootUsername != defaultUsername {
|
||||||
logger.Info("using existing username %s corresponding to user id %d", nonRootUsername, puid)
|
logger.Info("using existing username %s corresponding to user id %d", nonRootUsername, puid)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Chown("/etc/unbound", puid, pgid); err != nil {
|
if err := os.Chown("/etc/unbound", puid, pgid); err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if allSettings.Firewall.Debug {
|
if allSettings.Firewall.Debug {
|
||||||
@@ -161,27 +192,23 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
|
|
||||||
defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
|
defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
localSubnet, err := routingConf.LocalSubnet()
|
localSubnet, err := routingConf.LocalSubnet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultIP, err := routingConf.DefaultIP()
|
defaultIP, err := routingConf.DefaultIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP)
|
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP)
|
||||||
|
|
||||||
if err := routingConf.Setup(); err != nil {
|
if err := routingConf.Setup(); err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
routingConf.SetVerbose(false)
|
routingConf.SetVerbose(false)
|
||||||
@@ -191,20 +218,17 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if err := firewallConf.SetOutboundSubnets(ctx, allSettings.Firewall.OutboundSubnets); err != nil {
|
if err := firewallConf.SetOutboundSubnets(ctx, allSettings.Firewall.OutboundSubnets); err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
if err := routingConf.SetOutboundRoutes(allSettings.Firewall.OutboundSubnets); err != nil {
|
if err := routingConf.SetOutboundRoutes(allSettings.Firewall.OutboundSubnets); err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ovpnConf.CheckTUN(); err != nil {
|
if err := ovpnConf.CheckTUN(); err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
err = ovpnConf.CreateTUN()
|
err = ovpnConf.CreateTUN()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,30 +241,27 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
if allSettings.Firewall.Enabled {
|
if allSettings.Firewall.Enabled {
|
||||||
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
|
for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
|
||||||
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
|
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, port := range allSettings.Firewall.InputPorts {
|
for _, port := range allSettings.Firewall.InputPorts {
|
||||||
err = firewallConf.SetAllowedPort(ctx, port, defaultInterface)
|
err = firewallConf.SetAllowedPort(ctx, port, defaultInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
return err
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
} // TODO move inside firewall?
|
} // TODO move inside firewall?
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) // TODO waitgroup
|
||||||
|
|
||||||
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
||||||
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
|
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
|
||||||
@@ -296,55 +317,18 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
// until openvpn is launched
|
// until openvpn is launched
|
||||||
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
||||||
|
|
||||||
signalsCh := make(chan nativeos.Signal, 1)
|
<-ctx.Done()
|
||||||
signal.Notify(signalsCh,
|
|
||||||
syscall.SIGINT,
|
|
||||||
syscall.SIGTERM,
|
|
||||||
nativeos.Interrupt,
|
|
||||||
)
|
|
||||||
shutdownErrorsCount := 0
|
|
||||||
select {
|
|
||||||
case signal := <-signalsCh:
|
|
||||||
logger.Warn("Caught OS signal %s, shutting down", signal)
|
|
||||||
cancel()
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Warn("context canceled, shutting down")
|
|
||||||
}
|
|
||||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||||
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||||
if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
shutdownErrorsCount++
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const shutdownGracePeriod = 5 * time.Second
|
|
||||||
waiting, waited := context.WithTimeout(context.Background(), shutdownGracePeriod)
|
|
||||||
go func() {
|
|
||||||
defer waited()
|
|
||||||
wg.Wait()
|
|
||||||
}()
|
|
||||||
<-waiting.Done()
|
|
||||||
if waiting.Err() == context.DeadlineExceeded {
|
|
||||||
if shutdownErrorsCount > 0 {
|
|
||||||
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
|
|
||||||
}
|
|
||||||
logger.Warn("Shutdown timed out")
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
if shutdownErrorsCount > 0 {
|
|
||||||
logger.Warn("Shutdown had %d errors")
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
logger.Info("Shutdown successful")
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func createLogger() logging.Logger {
|
wg.Wait()
|
||||||
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
|
|
||||||
if err != nil {
|
return nil
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func printVersions(ctx context.Context, logger logging.Logger,
|
func printVersions(ctx context.Context, logger logging.Logger,
|
||||||
|
|||||||
Reference in New Issue
Block a user