Code maintenance: OS package for file system
- OS custom internal package for file system interaction - Remove fileManager external dependency - Closer API to Go's native API on the OS - Create directories at startup - Better testability - Move Unsetenv to os interface
This commit is contained in:
@@ -8,8 +8,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -21,11 +21,29 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
|
||||
for _, warning := range warnings {
|
||||
c.logger.Warn(warning)
|
||||
}
|
||||
return c.fileManager.WriteLinesToFile(
|
||||
string(constants.UnboundConf),
|
||||
lines,
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
|
||||
const filepath = string(constants.UnboundConf)
|
||||
file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.WriteString(strings.Join(lines, "\n"))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Chown(uid, gid); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeUnboundConf generates an Unbound configuration from the user provided settings.
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -24,19 +24,20 @@ type Configurator interface {
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
commander command.Commander
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
openFile os.OpenFileFunc
|
||||
commander command.Commander
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, client network.Client,
|
||||
openFile os.OpenFileFunc) Configurator {
|
||||
return &configurator{
|
||||
logger: logger.WithPrefix("dns configurator: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
commander: command.NewCommander(),
|
||||
lookupIP: net.LookupIP,
|
||||
logger: logger.WithPrefix("dns configurator: "),
|
||||
client: client,
|
||||
openFile: openFile,
|
||||
commander: command.NewCommander(),
|
||||
lookupIP: net.LookupIP,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
)
|
||||
|
||||
// UseDNSInternally is to change the Go program DNS only.
|
||||
@@ -23,10 +25,16 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
|
||||
// UseDNSSystemWide changes the nameserver to use for DNS system wide.
|
||||
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
|
||||
c.logger.Info("using DNS address %s system wide", ip.String())
|
||||
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
|
||||
const filepath = string(constants.ResolvConf)
|
||||
file, err := c.openFile(filepath, os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
s := strings.TrimSuffix(string(data), "\n")
|
||||
lines := strings.Split(s, "\n")
|
||||
if len(lines) == 1 && lines[0] == "" {
|
||||
@@ -44,6 +52,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
|
||||
if !found {
|
||||
lines = append(lines, "nameserver "+ip.String())
|
||||
}
|
||||
data = []byte(strings.Join(lines, "\n"))
|
||||
return c.fileManager.WriteToFile(string(constants.ResolvConf), data)
|
||||
s = strings.Join(lines, "\n")
|
||||
_, err = file.WriteString(s)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/os/mock_os"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -17,30 +19,36 @@ func Test_UseDNSSystemWide(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
writtenData []byte
|
||||
writtenData string
|
||||
openErr error
|
||||
readErr error
|
||||
writeErr error
|
||||
closeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
writtenData: []byte("nameserver 127.0.0.1"),
|
||||
writtenData: "nameserver 127.0.0.1",
|
||||
},
|
||||
"open error": {
|
||||
openErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
writtenData: []byte("nameserver 127.0.0.1"),
|
||||
writtenData: "nameserver 127.0.0.1",
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"lines without nameserver": {
|
||||
data: []byte("abc\ndef\n"),
|
||||
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"),
|
||||
writtenData: "abc\ndef\nnameserver 127.0.0.1",
|
||||
},
|
||||
"lines with nameserver": {
|
||||
data: []byte("abc\nnameserver abc def\ndef\n"),
|
||||
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"),
|
||||
writtenData: "abc\nnameserver 127.0.0.1\ndef",
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
@@ -49,18 +57,43 @@ func Test_UseDNSSystemWide(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
fileManager.EXPECT().ReadFile(string(constants.ResolvConf)).
|
||||
Return(tc.data, tc.readErr)
|
||||
if tc.readErr == nil {
|
||||
fileManager.EXPECT().WriteToFile(string(constants.ResolvConf), tc.writtenData).
|
||||
Return(tc.writeErr)
|
||||
|
||||
file := mock_os.NewMockFile(mockCtrl)
|
||||
if tc.openErr == nil {
|
||||
firstReadCall := file.EXPECT().
|
||||
Read(gomock.AssignableToTypeOf([]byte{})).
|
||||
DoAndReturn(func(b []byte) (int, error) {
|
||||
copy(b, tc.data)
|
||||
return len(tc.data), nil
|
||||
})
|
||||
readErr := tc.readErr
|
||||
if readErr == nil {
|
||||
readErr = io.EOF
|
||||
}
|
||||
finalReadCall := file.EXPECT().
|
||||
Read(gomock.AssignableToTypeOf([]byte{})).
|
||||
Return(0, readErr).After(firstReadCall)
|
||||
if tc.readErr == nil {
|
||||
writeCall := file.EXPECT().WriteString(tc.writtenData).
|
||||
Return(0, tc.writeErr).After(finalReadCall)
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
|
||||
} else {
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
|
||||
}
|
||||
}
|
||||
|
||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
assert.Equal(t, string(constants.ResolvConf), name)
|
||||
assert.Equal(t, os.O_RDWR, flag)
|
||||
assert.Equal(t, os.FileMode(0644), perm)
|
||||
return file, tc.openErr
|
||||
}
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
|
||||
c := &configurator{
|
||||
fileManager: fileManager,
|
||||
logger: logger,
|
||||
openFile: openFile,
|
||||
logger: logger,
|
||||
}
|
||||
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
|
||||
if tc.err != nil {
|
||||
|
||||
@@ -4,37 +4,46 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
)
|
||||
|
||||
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
|
||||
c.logger.Info("downloading root hints from %s", constants.NamedRootURL)
|
||||
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != http.StatusOK {
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
|
||||
}
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.RootHints),
|
||||
content,
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
return c.downloadAndSave(ctx, "root hints",
|
||||
string(constants.NamedRootURL), string(constants.RootHints), uid, gid)
|
||||
}
|
||||
|
||||
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
|
||||
c.logger.Info("downloading root key from %s", constants.RootKeyURL)
|
||||
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL))
|
||||
return c.downloadAndSave(ctx, "root key",
|
||||
string(constants.RootKeyURL), string(constants.RootKey), uid, gid)
|
||||
}
|
||||
|
||||
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
|
||||
c.logger.Info("downloading %s from %s", logName, url)
|
||||
content, status, err := c.client.Get(ctx, url)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != http.StatusOK {
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL)
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, url)
|
||||
}
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.RootKey),
|
||||
content,
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
|
||||
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write(content)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = file.Chown(uid, gid)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -2,27 +2,31 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/os/mock_os"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/golibs/network/mock_network"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
func Test_downloadAndSave(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
openErr error
|
||||
writeErr error
|
||||
chownErr error
|
||||
closeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
@@ -36,11 +40,26 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"open error": {
|
||||
status: http.StatusOK,
|
||||
openErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
status: http.StatusOK,
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"chown error": {
|
||||
status: http.StatusOK,
|
||||
chownErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"close error": {
|
||||
status: http.StatusOK,
|
||||
closeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
@@ -52,23 +71,49 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
fileManager.EXPECT().WriteToFile(
|
||||
string(constants.RootHints),
|
||||
tc.content,
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0))).
|
||||
Return(tc.writeErr)
|
||||
|
||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
file := mock_os.NewMockFile(mockCtrl)
|
||||
if tc.openErr == nil {
|
||||
writeCall := file.EXPECT().Write(tc.content).
|
||||
Return(0, tc.writeErr)
|
||||
if tc.writeErr != nil {
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
|
||||
} else {
|
||||
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
|
||||
}
|
||||
}
|
||||
|
||||
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
assert.Equal(t, string(constants.RootHints), name)
|
||||
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
|
||||
assert.Equal(t, os.FileMode(0400), perm)
|
||||
return file, tc.openErr
|
||||
}
|
||||
}
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
openFile: openFile,
|
||||
}
|
||||
|
||||
err := c.downloadAndSave(ctx, "root hints",
|
||||
string(constants.NamedRootURL), string(constants.RootHints),
|
||||
1000, 1000)
|
||||
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
@@ -79,65 +124,44 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
|
||||
func Test_DownloadRootHints(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
writeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"), //nolint:lll
|
||||
},
|
||||
"client error": {
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
status: http.StatusOK,
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL)
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
fileManager.EXPECT().WriteToFile(
|
||||
string(constants.RootKey),
|
||||
tc.content,
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
|
||||
).Return(tc.writeErr)
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootKey(ctx, 1000, 1001)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
||||
Return(nil, http.StatusOK, errors.New("test"))
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
|
||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "test", err.Error())
|
||||
}
|
||||
|
||||
func Test_DownloadRootKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
|
||||
Return(nil, http.StatusOK, errors.New("test"))
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
|
||||
err := c.DownloadRootKey(ctx, 1000, 1000)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "test", err.Error())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user