Reduced main.go code complexity

This commit is contained in:
Quentin McGaw
2020-05-02 14:48:18 +00:00
parent 6049b10209
commit 363fabc810
5 changed files with 238 additions and 288 deletions

40
internal/env/env.go vendored
View File

@@ -1,40 +0,0 @@
package env
import (
"context"
"github.com/qdm12/golibs/logging"
)
type Env interface {
FatalOnError(err error)
PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error))
}
type env struct {
logger logging.Logger
cancelContext func()
}
func New(logger logging.Logger, cancelContext context.CancelFunc) Env {
return &env{
logger: logger,
cancelContext: cancelContext,
}
}
func (e *env) FatalOnError(err error) {
if err != nil {
e.logger.Error(err)
e.cancelContext()
}
}
func (e *env) PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error)) {
version, err := commandFn(ctx)
if err != nil {
e.logger.Error(err)
} else {
e.logger.Info("%s version: %s", program, version)
}
}

View File

@@ -1,90 +0,0 @@
package env
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert"
)
func Test_FatalOnError(t *testing.T) {
t.Parallel()
tests := map[string]struct {
err error
}{
"nil": {},
"err": {fmt.Errorf("error")},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
var logged string
var canceled bool
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
logger := mock_logging.NewMockLogger(mockCtrl)
if tc.err != nil {
logger.EXPECT().Error(tc.err).Do(func(err error) {
logged = err.Error()
}).Times(1)
}
e := &env{
logger: logger,
cancelContext: func() { canceled = true },
}
e.FatalOnError(tc.err)
if tc.err != nil {
assert.Equal(t, logged, tc.err.Error())
assert.True(t, canceled)
} else {
assert.Empty(t, logged)
assert.False(t, canceled)
}
})
}
}
func Test_PrintVersion(t *testing.T) {
t.Parallel()
tests := map[string]struct {
program string
commandVersion string
commandErr error
}{
"no data": {},
"data": {"binu", "2.3-5", nil},
"error": {"binu", "", fmt.Errorf("error")},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
var logged string
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
logger := mock_logging.NewMockLogger(mockCtrl)
if tc.commandErr != nil {
logger.EXPECT().Error(tc.commandErr).Do(func(err error) {
logged = err.Error()
}).Times(1)
} else {
logger.EXPECT().Info("%s version: %s", tc.program, tc.commandVersion).
Do(func(format, program, version string) {
logged = fmt.Sprintf(format, program, version)
}).Times(1)
}
e := &env{logger: logger}
commandFn := func(ctx context.Context) (string, error) { return tc.commandVersion, tc.commandErr }
e.PrintVersion(context.Background(), tc.program, commandFn)
if tc.commandErr != nil {
assert.Equal(t, logged, tc.commandErr.Error())
} else {
assert.Equal(t, logged, fmt.Sprintf("%s version: %s", tc.program, tc.commandVersion))
}
})
}
}

View File

@@ -16,22 +16,28 @@ type Server interface {
}
type server struct {
address string
logger logging.Logger
restartOpenvpn func()
address string
logger logging.Logger
restartOpenvpn func()
restartOpenvpnSet context.Context
restartOpenvpnSetSignal func()
sync.RWMutex
}
func New(address string, logger logging.Logger) Server {
restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background())
return &server{
address: address,
logger: logger.WithPrefix("http server: "),
address: address,
logger: logger.WithPrefix("http server: "),
restartOpenvpnSet: restartOpenvpnSet,
restartOpenvpnSetSignal: restartOpenvpnSetSignal,
}
}
func (s *server) Run(ctx context.Context) error {
if s.restartOpenvpn == nil {
s.logger.Warn("restartOpenvpn function is not set")
if s.restartOpenvpnSet.Err() == nil {
s.logger.Warn("restartOpenvpn function is not set, waiting...")
<-s.restartOpenvpnSet.Done()
}
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
go func() {
@@ -50,6 +56,9 @@ func (s *server) SetOpenVPNRestart(f func()) {
s.Lock()
defer s.Unlock()
s.restartOpenvpn = f
if s.restartOpenvpnSet.Err() == nil {
s.restartOpenvpnSetSignal()
}
}
func (s *server) makeHandler() http.HandlerFunc {

View File

@@ -7,14 +7,10 @@ import (
"github.com/kyokomi/emoji"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// Splash returns the welcome spash message
func Splash(paramsReader params.Reader) string {
version := paramsReader.GetVersion()
vcsRef := paramsReader.GetVcsRef()
buildDate := paramsReader.GetBuildDate()
func Splash(version, vcsRef, buildDate string) string {
lines := title()
lines = append(lines, "")
lines = append(lines, fmt.Sprintf("Running version %s built on %s (commit %s)", version, buildDate, vcsRef))