Code maintenance: OS package for file system

- OS custom internal package for file system interaction
- Remove fileManager external dependency
- Closer API to Go's native API on the OS
- Create directories at startup
- Better testability
- Move Unsetenv to os interface
This commit is contained in:
Quentin McGaw
2020-12-29 00:55:31 +00:00
parent f5366c33bc
commit 73479bab26
43 changed files with 923 additions and 353 deletions

View File

@@ -8,8 +8,8 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
@@ -21,11 +21,29 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
for _, warning := range warnings {
c.logger.Warn(warning)
}
return c.fileManager.WriteLinesToFile(
string(constants.UnboundConf),
lines,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
const filepath = string(constants.UnboundConf)
file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil {
return err
}
_, err = file.WriteString(strings.Join(lines, "\n"))
if err != nil {
_ = file.Close()
return err
}
if err := file.Chown(uid, gid); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
}
// MakeUnboundConf generates an Unbound configuration from the user provided settings.

View File

@@ -5,9 +5,9 @@ import (
"io"
"net"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
@@ -24,19 +24,20 @@ type Configurator interface {
}
type configurator struct {
logger logging.Logger
client network.Client
fileManager files.FileManager
commander command.Commander
lookupIP func(host string) ([]net.IP, error)
logger logging.Logger
client network.Client
openFile os.OpenFileFunc
commander command.Commander
lookupIP func(host string) ([]net.IP, error)
}
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator {
func NewConfigurator(logger logging.Logger, client network.Client,
openFile os.OpenFileFunc) Configurator {
return &configurator{
logger: logger.WithPrefix("dns configurator: "),
client: client,
fileManager: fileManager,
commander: command.NewCommander(),
lookupIP: net.LookupIP,
logger: logger.WithPrefix("dns configurator: "),
client: client,
openFile: openFile,
commander: command.NewCommander(),
lookupIP: net.LookupIP,
}
}

View File

@@ -2,10 +2,12 @@ package dns
import (
"context"
"io/ioutil"
"net"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/os"
)
// UseDNSInternally is to change the Go program DNS only.
@@ -23,10 +25,16 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
// UseDNSSystemWide changes the nameserver to use for DNS system wide.
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String())
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
const filepath = string(constants.ResolvConf)
file, err := c.openFile(filepath, os.O_RDWR, 0644)
if err != nil {
return err
}
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" {
@@ -44,6 +52,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
if !found {
lines = append(lines, "nameserver "+ip.String())
}
data = []byte(strings.Join(lines, "\n"))
return c.fileManager.WriteToFile(string(constants.ResolvConf), data)
s = strings.Join(lines, "\n")
_, err = file.WriteString(s)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -2,12 +2,14 @@ package dns
import (
"fmt"
"io"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files/mock_files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -17,30 +19,36 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
writtenData []byte
writtenData string
openErr error
readErr error
writeErr error
closeErr error
err error
}{
"no data": {
writtenData: []byte("nameserver 127.0.0.1"),
writtenData: "nameserver 127.0.0.1",
},
"open error": {
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
writtenData: []byte("nameserver 127.0.0.1"),
writtenData: "nameserver 127.0.0.1",
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
data: []byte("abc\ndef\n"),
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"),
writtenData: "abc\ndef\nnameserver 127.0.0.1",
},
"lines with nameserver": {
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"),
writtenData: "abc\nnameserver 127.0.0.1\ndef",
},
}
for name, tc := range tests {
@@ -49,18 +57,43 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
fileManager := mock_files.NewMockFileManager(mockCtrl)
fileManager.EXPECT().ReadFile(string(constants.ResolvConf)).
Return(tc.data, tc.readErr)
if tc.readErr == nil {
fileManager.EXPECT().WriteToFile(string(constants.ResolvConf), tc.writtenData).
Return(tc.writeErr)
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
firstReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data)
return len(tc.data), nil
})
readErr := tc.readErr
if readErr == nil {
readErr = io.EOF
}
finalReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall)
if tc.readErr == nil {
writeCall := file.EXPECT().WriteString(tc.writtenData).
Return(0, tc.writeErr).After(finalReadCall)
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
}
}
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.ResolvConf), name)
assert.Equal(t, os.O_RDWR, flag)
assert.Equal(t, os.FileMode(0644), perm)
return file, tc.openErr
}
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
c := &configurator{
fileManager: fileManager,
logger: logger,
openFile: openFile,
logger: logger,
}
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil {

View File

@@ -4,37 +4,46 @@ import (
"context"
"fmt"
"net/http"
"os"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
)
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root hints from %s", constants.NamedRootURL)
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL))
if err != nil {
return err
} else if status != http.StatusOK {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
}
return c.fileManager.WriteToFile(
string(constants.RootHints),
content,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
return c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints), uid, gid)
}
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root key from %s", constants.RootKeyURL)
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL))
return c.downloadAndSave(ctx, "root key",
string(constants.RootKeyURL), string(constants.RootKey), uid, gid)
}
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
c.logger.Info("downloading %s from %s", logName, url)
content, status, err := c.client.Get(ctx, url)
if err != nil {
return err
} else if status != http.StatusOK {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL)
return fmt.Errorf("HTTP status code is %d for %s", status, url)
}
return c.fileManager.WriteToFile(
string(constants.RootKey),
content,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
if err != nil {
return err
}
_, err = file.Write(content)
if err != nil {
_ = file.Close()
return err
}
err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -2,27 +2,31 @@ package dns
import (
"context"
"errors"
"fmt"
"net/http"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/files/mock_files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/network/mock_network"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
func Test_downloadAndSave(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
openErr error
writeErr error
chownErr error
closeErr error
err error
}{
"no data": {
@@ -36,11 +40,26 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"open error": {
status: http.StatusOK,
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"chown error": {
status: http.StatusOK,
chownErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"close error": {
status: http.StatusOK,
closeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
@@ -52,23 +71,49 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
Return(tc.content, tc.status, tc.clientErr)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.EXPECT().WriteToFile(
string(constants.RootHints),
tc.content,
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
gomock.AssignableToTypeOf(files.Ownership(0, 0))).
Return(tc.writeErr)
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
return nil, nil
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootHints(ctx, 1000, 1000)
if tc.clientErr == nil && tc.status == http.StatusOK {
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
writeCall := file.EXPECT().Write(tc.content).
Return(0, tc.writeErr)
if tc.writeErr != nil {
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
}
}
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.RootHints), name)
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
assert.Equal(t, os.FileMode(0400), perm)
return file, tc.openErr
}
}
c := &configurator{
logger: logger,
client: client,
openFile: openFile,
}
err := c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints),
1000, 1000)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
@@ -79,65 +124,44 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
}
}
func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
func Test_DownloadRootHints(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
writeErr error
err error
}{
"no data": {
status: http.StatusOK,
},
"bad status": {
status: http.StatusBadRequest,
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"), //nolint:lll
},
"client error": {
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL)
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(tc.content, tc.status, tc.clientErr)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.EXPECT().WriteToFile(
string(constants.RootKey),
tc.content,
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
).Return(tc.writeErr)
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootKey(ctx, 1000, 1001)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
Return(nil, http.StatusOK, errors.New("test"))
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootHints(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, "test", err.Error())
}
func Test_DownloadRootKey(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(nil, http.StatusOK, errors.New("test"))
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootKey(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, "test", err.Error())
}