196 lines
4.8 KiB
Go
196 lines
4.8 KiB
Go
package dns
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/qdm12/gluetun/internal/constants"
|
|
"github.com/qdm12/golibs/logging/mock_logging"
|
|
"github.com/qdm12/golibs/os"
|
|
"github.com/qdm12/golibs/os/mock_os"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func Test_downloadAndSave(t *testing.T) {
|
|
t.Parallel()
|
|
const defaultURL = "https://test.com"
|
|
tests := map[string]struct {
|
|
url string // to trigger a new request error
|
|
content []byte
|
|
status int
|
|
clientErr error
|
|
openErr error
|
|
writeErr error
|
|
chownErr error
|
|
closeErr error
|
|
err error
|
|
}{
|
|
"no data": {
|
|
url: defaultURL,
|
|
status: http.StatusOK,
|
|
},
|
|
"bad status": {
|
|
url: defaultURL,
|
|
status: http.StatusBadRequest,
|
|
err: fmt.Errorf("bad HTTP status from %s: Bad Request", defaultURL),
|
|
},
|
|
"client error": {
|
|
url: defaultURL,
|
|
clientErr: fmt.Errorf("error"),
|
|
err: fmt.Errorf("Get %q: error", defaultURL),
|
|
},
|
|
"open error": {
|
|
url: defaultURL,
|
|
status: http.StatusOK,
|
|
openErr: fmt.Errorf("error"),
|
|
err: fmt.Errorf("error"),
|
|
},
|
|
"chown error": {
|
|
url: defaultURL,
|
|
status: http.StatusOK,
|
|
chownErr: fmt.Errorf("error"),
|
|
err: fmt.Errorf("error"),
|
|
},
|
|
"close error": {
|
|
url: defaultURL,
|
|
status: http.StatusOK,
|
|
closeErr: fmt.Errorf("error"),
|
|
err: fmt.Errorf("error"),
|
|
},
|
|
"data": {
|
|
url: defaultURL,
|
|
content: []byte("content"),
|
|
status: http.StatusOK,
|
|
},
|
|
}
|
|
for name, tc := range tests {
|
|
tc := tc
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
mockCtrl := gomock.NewController(t)
|
|
|
|
ctx := context.Background()
|
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
logger.EXPECT().Info("downloading %s from %s", "root hints", tc.url)
|
|
|
|
client := &http.Client{
|
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
assert.Equal(t, tc.url, r.URL.String())
|
|
if tc.clientErr != nil {
|
|
return nil, tc.clientErr
|
|
}
|
|
return &http.Response{
|
|
StatusCode: tc.status,
|
|
Status: http.StatusText(tc.status),
|
|
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
|
|
}, nil
|
|
}),
|
|
}
|
|
|
|
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
const filepath = "/test"
|
|
|
|
if tc.clientErr == nil && tc.status == http.StatusOK {
|
|
file := mock_os.NewMockFile(mockCtrl)
|
|
if tc.openErr == nil {
|
|
if len(tc.content) > 0 {
|
|
file.EXPECT().
|
|
Write(tc.content).
|
|
Return(len(tc.content), tc.writeErr)
|
|
}
|
|
file.EXPECT().
|
|
Close().
|
|
Return(tc.closeErr)
|
|
file.EXPECT().
|
|
Chown(1000, 1000).
|
|
Return(tc.chownErr)
|
|
}
|
|
|
|
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
|
|
assert.Equal(t, filepath, name)
|
|
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
|
|
assert.Equal(t, os.FileMode(0400), perm)
|
|
return file, tc.openErr
|
|
}
|
|
}
|
|
|
|
c := &configurator{
|
|
logger: logger,
|
|
client: client,
|
|
openFile: openFile,
|
|
}
|
|
|
|
err := c.downloadAndSave(ctx, "root hints",
|
|
tc.url, filepath,
|
|
1000, 1000)
|
|
|
|
if tc.err != nil {
|
|
require.Error(t, err)
|
|
assert.Equal(t, tc.err.Error(), err.Error())
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_DownloadRootHints(t *testing.T) {
|
|
t.Parallel()
|
|
mockCtrl := gomock.NewController(t)
|
|
|
|
ctx := context.Background()
|
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
|
|
|
client := &http.Client{
|
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
assert.Equal(t, string(constants.NamedRootURL), r.URL.String())
|
|
return nil, errors.New("test")
|
|
}),
|
|
}
|
|
|
|
c := &configurator{
|
|
logger: logger,
|
|
client: client,
|
|
}
|
|
|
|
err := c.DownloadRootHints(ctx, 1000, 1000)
|
|
require.Error(t, err)
|
|
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated": test`, err.Error())
|
|
}
|
|
|
|
func Test_DownloadRootKey(t *testing.T) {
|
|
t.Parallel()
|
|
mockCtrl := gomock.NewController(t)
|
|
|
|
ctx := context.Background()
|
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
|
|
|
|
client := &http.Client{
|
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
assert.Equal(t, string(constants.RootKeyURL), r.URL.String())
|
|
return nil, errors.New("test")
|
|
}),
|
|
}
|
|
|
|
c := &configurator{
|
|
logger: logger,
|
|
client: client,
|
|
}
|
|
|
|
err := c.DownloadRootKey(ctx, 1000, 1000)
|
|
require.Error(t, err)
|
|
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated": test`, err.Error())
|
|
}
|