diff --git a/cmd/main.go b/cmd/main.go index 9a982d81..0147203b 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -49,7 +49,11 @@ func main() { //nolint:gocognit } paramsReader := params.NewReader(logger) fmt.Println(splash.Splash(paramsReader)) - e := env.New(logger) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + e := env.New(logger, cancel) + client := network.NewClient(15 * time.Second) // Create configurators fileManager := files.NewFileManager() @@ -63,8 +67,6 @@ func main() { //nolint:gocognit windscribeConf := windscribe.NewConfigurator(fileManager) tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger) shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() streamMerger := command.NewStreamMerger() e.PrintVersion(ctx, "OpenVPN", ovpnConf.Version) diff --git a/internal/env/env.go b/internal/env/env.go index fee5bf38..5410a1f0 100644 --- a/internal/env/env.go +++ b/internal/env/env.go @@ -2,7 +2,6 @@ package env import ( "context" - "os" "github.com/qdm12/golibs/logging" ) @@ -13,21 +12,21 @@ type Env interface { } type env struct { - logger logging.Logger - osExit func(n int) + logger logging.Logger + cancelContext func() } -func New(logger logging.Logger) Env { +func New(logger logging.Logger, cancelContext context.CancelFunc) Env { return &env{ - logger: logger, - osExit: os.Exit, + logger: logger, + cancelContext: cancelContext, } } func (e *env) FatalOnError(err error) { if err != nil { e.logger.Error(err) - e.osExit(1) + e.cancelContext() } } diff --git a/internal/env/env_test.go b/internal/env/env_test.go index 59db21fc..26f7761a 100644 --- a/internal/env/env_test.go +++ b/internal/env/env_test.go @@ -23,7 +23,7 @@ func Test_FatalOnError(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() var logged string - var exitCode int + var canceled bool mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() logger := mock_logging.NewMockLogger(mockCtrl) @@ -32,15 +32,17 @@ func Test_FatalOnError(t *testing.T) { logged = err.Error() }).Times(1) } - osExit := func(n int) { exitCode = n } - e := &env{logger, osExit} + e := &env{ + logger: logger, + cancelContext: func() { canceled = true }, + } e.FatalOnError(tc.err) if tc.err != nil { assert.Equal(t, logged, tc.err.Error()) - assert.Equal(t, exitCode, 1) + assert.True(t, canceled) } else { assert.Empty(t, logged) - assert.Zero(t, exitCode) + assert.False(t, canceled) } }) }