Code maintenance: use native Go HTTP client
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -11,7 +14,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/golibs/network/mock_network"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -32,17 +34,45 @@ func Test_generateUnboundConf(t *testing.T) {
|
||||
}
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return([]byte("b\na\nc"), 200, nil)
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
||||
Return([]byte("c\nd\n"), 200, nil)
|
||||
|
||||
clientCalls := map[models.URL]int{
|
||||
constants.MaliciousBlockListIPsURL: 0,
|
||||
constants.MaliciousBlockListHostnamesURL: 0,
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body string
|
||||
switch url {
|
||||
case constants.MaliciousBlockListIPsURL:
|
||||
body = "c\nd"
|
||||
case constants.MaliciousBlockListHostnamesURL:
|
||||
body = "b\na\nc"
|
||||
default:
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("%d hostnames blocked overall", 2)
|
||||
logger.EXPECT().Info("%d IP addresses blocked overall", 3)
|
||||
lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
|
||||
require.Len(t, warnings, 0)
|
||||
expected := `
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
const expected = `
|
||||
server:
|
||||
cache-max-ttl: 9000
|
||||
cache-min-ttl: 3600
|
||||
@@ -209,7 +239,10 @@ func Test_buildBlocked(t *testing.T) {
|
||||
ipsLines: []string{
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
errsString: []string{"ads error", "ads error"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads error`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads error`,
|
||||
},
|
||||
},
|
||||
"all blocked with errors": {
|
||||
malicious: blockParams{
|
||||
@@ -224,37 +257,74 @@ func Test_buildBlocked(t *testing.T) {
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance"),
|
||||
},
|
||||
errsString: []string{"malicious", "malicious", "ads", "ads", "surveillance", "surveillance"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated": malicious`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated": malicious`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance`,
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
|
||||
clientCalls := map[models.URL]int{}
|
||||
if tc.malicious.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
||||
clientCalls[constants.MaliciousBlockListIPsURL] = 0
|
||||
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
||||
clientCalls[constants.AdsBlockListIPsURL] = 0
|
||||
clientCalls[constants.AdsBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
||||
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
|
||||
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body []byte
|
||||
var err error
|
||||
switch url {
|
||||
case constants.MaliciousBlockListIPsURL, constants.MaliciousBlockListHostnamesURL:
|
||||
body = tc.malicious.content
|
||||
err = tc.malicious.clientErr
|
||||
case constants.AdsBlockListIPsURL, constants.AdsBlockListHostnamesURL:
|
||||
body = tc.ads.content
|
||||
err = tc.ads.clientErr
|
||||
case constants.SurveillanceBlockListIPsURL, constants.SurveillanceBlockListHostnamesURL:
|
||||
body = tc.surveillance.content
|
||||
err = tc.surveillance.clientErr
|
||||
default: // just in case if the test is badly written
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
||||
tc.allowedHostnames, tc.privateAddresses)
|
||||
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
@@ -262,6 +332,10 @@ func Test_buildBlocked(t *testing.T) {
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
|
||||
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
|
||||
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -275,20 +349,45 @@ func Test_getList(t *testing.T) {
|
||||
results []string
|
||||
err error
|
||||
}{
|
||||
"no result": {nil, 200, nil, nil, nil},
|
||||
"bad status": {nil, 500, nil, nil, fmt.Errorf("HTTP status code is 500 and not 200")},
|
||||
"network error": {nil, 200, fmt.Errorf("error"), nil, fmt.Errorf("error")},
|
||||
"results": {[]byte("a\nb\nc\n"), 200, nil, []string{"a", "b", "c"}, nil},
|
||||
"no result": {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
status: http.StatusInternalServerError,
|
||||
err: fmt.Errorf("bad HTTP status from irrelevant_url: Internal Server Error"),
|
||||
},
|
||||
"network error": {
|
||||
status: http.StatusOK,
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf(`Get "irrelevant_url": error`),
|
||||
},
|
||||
"results": {
|
||||
content: []byte("a\nb\nc\n"),
|
||||
status: http.StatusOK,
|
||||
results: []string{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, "irrelevant_url").
|
||||
Return(tc.content, tc.status, tc.clientErr)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, "irrelevant_url", r.URL.String())
|
||||
if tc.clientErr != nil {
|
||||
return nil, tc.clientErr
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: tc.status,
|
||||
Status: http.StatusText(tc.status),
|
||||
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
results, err := getList(ctx, client, "irrelevant_url")
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
@@ -316,10 +415,7 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
||||
lines []string
|
||||
errsString []string
|
||||
}{
|
||||
"nothing blocked": {
|
||||
lines: nil,
|
||||
errsString: nil,
|
||||
},
|
||||
"nothing blocked": {},
|
||||
"only malicious blocked": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
@@ -329,7 +425,6 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with some duplicates": {
|
||||
malicious: blockParams{
|
||||
@@ -348,7 +443,6 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static",
|
||||
" local-zone: \"site_c\" static"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with one errored": {
|
||||
malicious: blockParams{
|
||||
@@ -367,7 +461,9 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static",
|
||||
" local-zone: \"site_c\" static"},
|
||||
errsString: []string{"surveillance error"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance error`,
|
||||
},
|
||||
},
|
||||
"blocked with allowed hostnames": {
|
||||
malicious: blockParams{
|
||||
@@ -384,34 +480,71 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
||||
" local-zone: \"site_d\" static"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests { //nolint:dupl
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
|
||||
clientCalls := map[models.URL]int{}
|
||||
if tc.malicious.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
||||
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
||||
clientCalls[constants.AdsBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
||||
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body []byte
|
||||
var err error
|
||||
switch url {
|
||||
case constants.MaliciousBlockListHostnamesURL:
|
||||
body = tc.malicious.content
|
||||
err = tc.malicious.clientErr
|
||||
case constants.AdsBlockListHostnamesURL:
|
||||
body = tc.ads.content
|
||||
err = tc.ads.clientErr
|
||||
case constants.SurveillanceBlockListHostnamesURL:
|
||||
body = tc.surveillance.content
|
||||
err = tc.surveillance.clientErr
|
||||
default: // just in case if the test is badly written
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
lines, errs := buildBlockedHostnames(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked,
|
||||
tc.surveillance.blocked, tc.allowedHostnames)
|
||||
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.lines, lines)
|
||||
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -431,10 +564,7 @@ func Test_buildBlockedIPs(t *testing.T) {
|
||||
lines []string
|
||||
errsString []string
|
||||
}{
|
||||
"nothing blocked": {
|
||||
lines: nil,
|
||||
errsString: nil,
|
||||
},
|
||||
"nothing blocked": {},
|
||||
"only malicious blocked": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
@@ -444,7 +574,6 @@ func Test_buildBlockedIPs(t *testing.T) {
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with some duplicates": {
|
||||
malicious: blockParams{
|
||||
@@ -463,7 +592,6 @@ func Test_buildBlockedIPs(t *testing.T) {
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with one errored": {
|
||||
malicious: blockParams{
|
||||
@@ -482,7 +610,9 @@ func Test_buildBlockedIPs(t *testing.T) {
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c"},
|
||||
errsString: []string{"surveillance error"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance error`,
|
||||
},
|
||||
},
|
||||
"blocked with private addresses": {
|
||||
malicious: blockParams{
|
||||
@@ -501,34 +631,72 @@ func Test_buildBlockedIPs(t *testing.T) {
|
||||
" private-address: site_d"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests { //nolint:dupl
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
|
||||
clientCalls := map[models.URL]int{}
|
||||
if tc.malicious.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
||||
clientCalls[constants.MaliciousBlockListIPsURL] = 0
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
||||
clientCalls[constants.AdsBlockListIPsURL] = 0
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
||||
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body []byte
|
||||
var err error
|
||||
switch url {
|
||||
case constants.MaliciousBlockListIPsURL:
|
||||
body = tc.malicious.content
|
||||
err = tc.malicious.clientErr
|
||||
case constants.AdsBlockListIPsURL:
|
||||
body = tc.ads.content
|
||||
err = tc.ads.clientErr
|
||||
case constants.SurveillanceBlockListIPsURL:
|
||||
body = tc.surveillance.content
|
||||
err = tc.surveillance.clientErr
|
||||
default: // just in case if the test is badly written
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
lines, errs := buildBlockedIPs(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked,
|
||||
tc.surveillance.blocked, tc.privateAddresses)
|
||||
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.lines, lines)
|
||||
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user