Code maintenance: use native Go HTTP client

This commit is contained in:
Quentin McGaw
2020-12-29 02:55:34 +00:00
parent 60e98235ca
commit 8d5f2fec09
11 changed files with 379 additions and 122 deletions

View File

@@ -36,7 +36,6 @@ import (
versionpkg "github.com/qdm12/gluetun/internal/version" versionpkg "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
) )
//nolint:gochecknoglobals //nolint:gochecknoglobals
@@ -89,11 +88,10 @@ func _main(background context.Context, buildInfo models.BuildInformation,
const clientTimeout = 15 * time.Second const clientTimeout = 15 * time.Second
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
client := network.NewClient(clientTimeout)
// Create configurators // Create configurators
alpineConf := alpine.NewConfigurator(os.OpenFile, osUser) alpineConf := alpine.NewConfigurator(os.OpenFile, osUser)
ovpnConf := openvpn.NewConfigurator(logger, os, unix) ovpnConf := openvpn.NewConfigurator(logger, os, unix)
dnsConf := dns.NewConfigurator(logger, client, os.OpenFile) dnsConf := dns.NewConfigurator(logger, httpClient, os.OpenFile)
routingConf := routing.NewRouting(logger) routingConf := routing.NewRouting(logger)
firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile) firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile)
streamMerger := command.NewStreamMerger() streamMerger := command.NewStreamMerger()
@@ -253,7 +251,7 @@ func _main(background context.Context, buildInfo models.BuildInformation,
go unboundLooper.Run(ctx, wg, signalDNSReady) go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper( publicIPLooper := publicip.NewLooper(
client, logger, allSettings.PublicIP, uid, gid, os) httpClient, logger, allSettings.PublicIP, uid, gid, os)
wg.Add(1) wg.Add(1)
go publicIPLooper.Run(ctx, wg) go publicIPLooper.Run(ctx, wg)
wg.Add(1) wg.Add(1)

View File

@@ -3,6 +3,7 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"sort" "sort"
"strings" "strings"
@@ -11,7 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
) )
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS, func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS,
@@ -48,7 +48,7 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
// MakeUnboundConf generates an Unbound configuration from the user provided settings. // MakeUnboundConf generates an Unbound configuration from the user provided settings.
func generateUnboundConf(ctx context.Context, settings settings.DNS, username string, func generateUnboundConf(ctx context.Context, settings settings.DNS, username string,
client network.Client, logger logging.Logger) ( client *http.Client, logger logging.Logger) (
lines []string, warnings []error) { lines []string, warnings []error) {
doIPv6 := "no" doIPv6 := "no"
if settings.IPv6 { if settings.IPv6 {
@@ -151,7 +151,7 @@ func generateUnboundConf(ctx context.Context, settings settings.DNS, username st
return lines, warnings return lines, warnings
} }
func buildBlocked(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool, func buildBlocked(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) { allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
chHostnames := make(chan []string) chHostnames := make(chan []string)
chIPs := make(chan []string) chIPs := make(chan []string)
@@ -181,13 +181,27 @@ func buildBlocked(ctx context.Context, client network.Client, blockMalicious, bl
return hostnamesLines, ipsLines, errs return hostnamesLines, ipsLines, errs
} }
func getList(ctx context.Context, client network.Client, url string) (results []string, err error) { func getList(ctx context.Context, client *http.Client, url string) (results []string, err error) {
content, status, err := client.Get(ctx, url) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != http.StatusOK {
return nil, fmt.Errorf("HTTP status code is %d and not 200", status)
} }
response, err := client.Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
}
content, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
}
results = strings.Split(string(content), "\n") results = strings.Split(string(content), "\n")
// remove empty lines // remove empty lines
@@ -206,7 +220,7 @@ func getList(ctx context.Context, client network.Client, url string) (results []
return results, nil return results, nil
} }
func buildBlockedHostnames(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool, func buildBlockedHostnames(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames []string) (lines []string, errs []error) { allowedHostnames []string) (lines []string, errs []error) {
chResults := make(chan []string) chResults := make(chan []string)
chError := make(chan error) chError := make(chan error)
@@ -258,7 +272,7 @@ func buildBlockedHostnames(ctx context.Context, client network.Client, blockMali
return lines, errs return lines, errs
} }
func buildBlockedIPs(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool, func buildBlockedIPs(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
privateAddresses []string) (lines []string, errs []error) { privateAddresses []string) (lines []string, errs []error) {
chResults := make(chan []string) chResults := make(chan []string)
chError := make(chan error) chError := make(chan error)

View File

@@ -1,8 +1,11 @@
package dns package dns
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io/ioutil"
"net/http"
"strings" "strings"
"testing" "testing"
@@ -11,7 +14,6 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/logging/mock_logging" "github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/network/mock_network"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -32,17 +34,45 @@ func Test_generateUnboundConf(t *testing.T) {
} }
mockCtrl := gomock.NewController(t) mockCtrl := gomock.NewController(t)
ctx := context.Background() ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)). clientCalls := map[models.URL]int{
Return([]byte("b\na\nc"), 200, nil) constants.MaliciousBlockListIPsURL: 0,
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)). constants.MaliciousBlockListHostnamesURL: 0,
Return([]byte("c\nd\n"), 200, nil) }
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body string
switch url {
case constants.MaliciousBlockListIPsURL:
body = "c\nd"
case constants.MaliciousBlockListHostnamesURL:
body = "b\na\nc"
default:
t.Errorf("unknown URL %q", url)
return nil, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader(body)),
}, nil
}),
}
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("%d hostnames blocked overall", 2) logger.EXPECT().Info("%d hostnames blocked overall", 2)
logger.EXPECT().Info("%d IP addresses blocked overall", 3) logger.EXPECT().Info("%d IP addresses blocked overall", 3)
lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger) lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
require.Len(t, warnings, 0) require.Len(t, warnings, 0)
expected := ` for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
const expected = `
server: server:
cache-max-ttl: 9000 cache-max-ttl: 9000
cache-min-ttl: 3600 cache-min-ttl: 3600
@@ -209,7 +239,10 @@ func Test_buildBlocked(t *testing.T) {
ipsLines: []string{ ipsLines: []string{
" private-address: malicious", " private-address: malicious",
" private-address: surveillance"}, " private-address: surveillance"},
errsString: []string{"ads error", "ads error"}, errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads error`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads error`,
},
}, },
"all blocked with errors": { "all blocked with errors": {
malicious: blockParams{ malicious: blockParams{
@@ -224,37 +257,74 @@ func Test_buildBlocked(t *testing.T) {
blocked: true, blocked: true,
clientErr: fmt.Errorf("surveillance"), clientErr: fmt.Errorf("surveillance"),
}, },
errsString: []string{"malicious", "malicious", "ads", "ads", "surveillance", "surveillance"}, errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated": malicious`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated": malicious`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance`,
},
}, },
} }
for name, tc := range tests { for name, tc := range tests {
tc := tc tc := tc
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background() ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
clientCalls := map[models.URL]int{}
if tc.malicious.blocked { if tc.malicious.blocked {
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)). clientCalls[constants.MaliciousBlockListIPsURL] = 0
Return(tc.malicious.content, 200, tc.malicious.clientErr) clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr)
} }
if tc.ads.blocked { if tc.ads.blocked {
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)). clientCalls[constants.AdsBlockListIPsURL] = 0
Return(tc.ads.content, 200, tc.ads.clientErr) clientCalls[constants.AdsBlockListHostnamesURL] = 0
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
Return(tc.ads.content, 200, tc.ads.clientErr)
} }
if tc.surveillance.blocked { if tc.surveillance.blocked {
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)). clientCalls[constants.SurveillanceBlockListIPsURL] = 0
Return(tc.surveillance.content, 200, tc.surveillance.clientErr) clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
} }
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body []byte
var err error
switch url {
case constants.MaliciousBlockListIPsURL, constants.MaliciousBlockListHostnamesURL:
body = tc.malicious.content
err = tc.malicious.clientErr
case constants.AdsBlockListIPsURL, constants.AdsBlockListHostnamesURL:
body = tc.ads.content
err = tc.ads.clientErr
case constants.SurveillanceBlockListIPsURL, constants.SurveillanceBlockListHostnamesURL:
body = tc.surveillance.content
err = tc.surveillance.clientErr
default: // just in case if the test is badly written
t.Errorf("unknown URL %q", url)
return nil, nil
}
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}, nil
}),
}
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client, hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
tc.allowedHostnames, tc.privateAddresses) tc.allowedHostnames, tc.privateAddresses)
var errsString []string var errsString []string
for _, err := range errs { for _, err := range errs {
errsString = append(errsString, err.Error()) errsString = append(errsString, err.Error())
@@ -262,6 +332,10 @@ func Test_buildBlocked(t *testing.T) {
assert.ElementsMatch(t, tc.errsString, errsString) assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines) assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
assert.ElementsMatch(t, tc.ipsLines, ipsLines) assert.ElementsMatch(t, tc.ipsLines, ipsLines)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
}) })
} }
} }
@@ -275,20 +349,45 @@ func Test_getList(t *testing.T) {
results []string results []string
err error err error
}{ }{
"no result": {nil, 200, nil, nil, nil}, "no result": {
"bad status": {nil, 500, nil, nil, fmt.Errorf("HTTP status code is 500 and not 200")}, status: http.StatusOK,
"network error": {nil, 200, fmt.Errorf("error"), nil, fmt.Errorf("error")}, },
"results": {[]byte("a\nb\nc\n"), 200, nil, []string{"a", "b", "c"}, nil}, "bad status": {
status: http.StatusInternalServerError,
err: fmt.Errorf("bad HTTP status from irrelevant_url: Internal Server Error"),
},
"network error": {
status: http.StatusOK,
clientErr: fmt.Errorf("error"),
err: fmt.Errorf(`Get "irrelevant_url": error`),
},
"results": {
content: []byte("a\nb\nc\n"),
status: http.StatusOK,
results: []string{"a", "b", "c"},
},
} }
for name, tc := range tests { for name, tc := range tests {
tc := tc tc := tc
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background() ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, "irrelevant_url"). client := &http.Client{
Return(tc.content, tc.status, tc.clientErr) Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "irrelevant_url", r.URL.String())
if tc.clientErr != nil {
return nil, tc.clientErr
}
return &http.Response{
StatusCode: tc.status,
Status: http.StatusText(tc.status),
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
}, nil
}),
}
results, err := getList(ctx, client, "irrelevant_url") results, err := getList(ctx, client, "irrelevant_url")
if tc.err != nil { if tc.err != nil {
require.Error(t, err) require.Error(t, err)
@@ -316,10 +415,7 @@ func Test_buildBlockedHostnames(t *testing.T) {
lines []string lines []string
errsString []string errsString []string
}{ }{
"nothing blocked": { "nothing blocked": {},
lines: nil,
errsString: nil,
},
"only malicious blocked": { "only malicious blocked": {
malicious: blockParams{ malicious: blockParams{
blocked: true, blocked: true,
@@ -329,7 +425,6 @@ func Test_buildBlockedHostnames(t *testing.T) {
lines: []string{ lines: []string{
" local-zone: \"site_a\" static", " local-zone: \"site_a\" static",
" local-zone: \"site_b\" static"}, " local-zone: \"site_b\" static"},
errsString: nil,
}, },
"all blocked with some duplicates": { "all blocked with some duplicates": {
malicious: blockParams{ malicious: blockParams{
@@ -348,7 +443,6 @@ func Test_buildBlockedHostnames(t *testing.T) {
" local-zone: \"site_a\" static", " local-zone: \"site_a\" static",
" local-zone: \"site_b\" static", " local-zone: \"site_b\" static",
" local-zone: \"site_c\" static"}, " local-zone: \"site_c\" static"},
errsString: nil,
}, },
"all blocked with one errored": { "all blocked with one errored": {
malicious: blockParams{ malicious: blockParams{
@@ -367,7 +461,9 @@ func Test_buildBlockedHostnames(t *testing.T) {
" local-zone: \"site_a\" static", " local-zone: \"site_a\" static",
" local-zone: \"site_b\" static", " local-zone: \"site_b\" static",
" local-zone: \"site_c\" static"}, " local-zone: \"site_c\" static"},
errsString: []string{"surveillance error"}, errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance error`,
},
}, },
"blocked with allowed hostnames": { "blocked with allowed hostnames": {
malicious: blockParams{ malicious: blockParams{
@@ -384,34 +480,71 @@ func Test_buildBlockedHostnames(t *testing.T) {
" local-zone: \"site_d\" static"}, " local-zone: \"site_d\" static"},
}, },
} }
for name, tc := range tests { //nolint:dupl for name, tc := range tests {
tc := tc tc := tc
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background() ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
clientCalls := map[models.URL]int{}
if tc.malicious.blocked { if tc.malicious.blocked {
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)). clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
Return(tc.malicious.content, 200, tc.malicious.clientErr)
} }
if tc.ads.blocked { if tc.ads.blocked {
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)). clientCalls[constants.AdsBlockListHostnamesURL] = 0
Return(tc.ads.content, 200, tc.ads.clientErr)
} }
if tc.surveillance.blocked { if tc.surveillance.blocked {
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)). clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
} }
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body []byte
var err error
switch url {
case constants.MaliciousBlockListHostnamesURL:
body = tc.malicious.content
err = tc.malicious.clientErr
case constants.AdsBlockListHostnamesURL:
body = tc.ads.content
err = tc.ads.clientErr
case constants.SurveillanceBlockListHostnamesURL:
body = tc.surveillance.content
err = tc.surveillance.clientErr
default: // just in case if the test is badly written
t.Errorf("unknown URL %q", url)
return nil, nil
}
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}, nil
}),
}
lines, errs := buildBlockedHostnames(ctx, client, lines, errs := buildBlockedHostnames(ctx, client,
tc.malicious.blocked, tc.ads.blocked, tc.malicious.blocked, tc.ads.blocked,
tc.surveillance.blocked, tc.allowedHostnames) tc.surveillance.blocked, tc.allowedHostnames)
var errsString []string var errsString []string
for _, err := range errs { for _, err := range errs {
errsString = append(errsString, err.Error()) errsString = append(errsString, err.Error())
} }
assert.ElementsMatch(t, tc.errsString, errsString) assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.lines, lines) assert.ElementsMatch(t, tc.lines, lines)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
}) })
} }
} }
@@ -431,10 +564,7 @@ func Test_buildBlockedIPs(t *testing.T) {
lines []string lines []string
errsString []string errsString []string
}{ }{
"nothing blocked": { "nothing blocked": {},
lines: nil,
errsString: nil,
},
"only malicious blocked": { "only malicious blocked": {
malicious: blockParams{ malicious: blockParams{
blocked: true, blocked: true,
@@ -444,7 +574,6 @@ func Test_buildBlockedIPs(t *testing.T) {
lines: []string{ lines: []string{
" private-address: site_a", " private-address: site_a",
" private-address: site_b"}, " private-address: site_b"},
errsString: nil,
}, },
"all blocked with some duplicates": { "all blocked with some duplicates": {
malicious: blockParams{ malicious: blockParams{
@@ -463,7 +592,6 @@ func Test_buildBlockedIPs(t *testing.T) {
" private-address: site_a", " private-address: site_a",
" private-address: site_b", " private-address: site_b",
" private-address: site_c"}, " private-address: site_c"},
errsString: nil,
}, },
"all blocked with one errored": { "all blocked with one errored": {
malicious: blockParams{ malicious: blockParams{
@@ -482,7 +610,9 @@ func Test_buildBlockedIPs(t *testing.T) {
" private-address: site_a", " private-address: site_a",
" private-address: site_b", " private-address: site_b",
" private-address: site_c"}, " private-address: site_c"},
errsString: []string{"surveillance error"}, errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance error`,
},
}, },
"blocked with private addresses": { "blocked with private addresses": {
malicious: blockParams{ malicious: blockParams{
@@ -501,34 +631,72 @@ func Test_buildBlockedIPs(t *testing.T) {
" private-address: site_d"}, " private-address: site_d"},
}, },
} }
for name, tc := range tests { //nolint:dupl for name, tc := range tests {
tc := tc tc := tc
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background() ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
clientCalls := map[models.URL]int{}
if tc.malicious.blocked { if tc.malicious.blocked {
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)). clientCalls[constants.MaliciousBlockListIPsURL] = 0
Return(tc.malicious.content, 200, tc.malicious.clientErr)
} }
if tc.ads.blocked { if tc.ads.blocked {
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)). clientCalls[constants.AdsBlockListIPsURL] = 0
Return(tc.ads.content, 200, tc.ads.clientErr)
} }
if tc.surveillance.blocked { if tc.surveillance.blocked {
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)). clientCalls[constants.SurveillanceBlockListIPsURL] = 0
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
} }
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body []byte
var err error
switch url {
case constants.MaliciousBlockListIPsURL:
body = tc.malicious.content
err = tc.malicious.clientErr
case constants.AdsBlockListIPsURL:
body = tc.ads.content
err = tc.ads.clientErr
case constants.SurveillanceBlockListIPsURL:
body = tc.surveillance.content
err = tc.surveillance.clientErr
default: // just in case if the test is badly written
t.Errorf("unknown URL %q", url)
return nil, nil
}
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}, nil
}),
}
lines, errs := buildBlockedIPs(ctx, client, lines, errs := buildBlockedIPs(ctx, client,
tc.malicious.blocked, tc.ads.blocked, tc.malicious.blocked, tc.ads.blocked,
tc.surveillance.blocked, tc.privateAddresses) tc.surveillance.blocked, tc.privateAddresses)
var errsString []string var errsString []string
for _, err := range errs { for _, err := range errs {
errsString = append(errsString, err.Error()) errsString = append(errsString, err.Error())
} }
assert.ElementsMatch(t, tc.errsString, errsString) assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.lines, lines) assert.ElementsMatch(t, tc.lines, lines)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
}) })
} }
} }

View File

@@ -4,12 +4,12 @@ import (
"context" "context"
"io" "io"
"net" "net"
"net/http"
"github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
) )
type Configurator interface { type Configurator interface {
@@ -25,17 +25,17 @@ type Configurator interface {
type configurator struct { type configurator struct {
logger logging.Logger logger logging.Logger
client network.Client client *http.Client
openFile os.OpenFileFunc openFile os.OpenFileFunc
commander command.Commander commander command.Commander
lookupIP func(host string) ([]net.IP, error) lookupIP func(host string) ([]net.IP, error)
} }
func NewConfigurator(logger logging.Logger, client network.Client, func NewConfigurator(logger logging.Logger, httpClient *http.Client,
openFile os.OpenFileFunc) Configurator { openFile os.OpenFileFunc) Configurator {
return &configurator{ return &configurator{
logger: logger.WithPrefix("dns configurator: "), logger: logger.WithPrefix("dns configurator: "),
client: client, client: httpClient,
openFile: openFile, openFile: openFile,
commander: command.NewCommander(), commander: command.NewCommander(),
lookupIP: net.LookupIP, lookupIP: net.LookupIP,

8
internal/dns/errors.go Normal file
View File

@@ -0,0 +1,8 @@
package dns
import "errors"
var (
ErrBadStatusCode = errors.New("bad HTTP status")
ErrCannotReadBody = errors.New("cannot read response body")
)

View File

@@ -3,6 +3,7 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net/http" "net/http"
"os" "os"
@@ -21,11 +22,19 @@ func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error { func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
c.logger.Info("downloading %s from %s", logName, url) c.logger.Info("downloading %s from %s", logName, url)
content, status, err := c.client.Get(ctx, url) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return err return err
} else if status != http.StatusOK { }
return fmt.Errorf("HTTP status code is %d for %s", status, url)
response, err := c.client.Do(req)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
} }
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400) file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
@@ -33,7 +42,7 @@ func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepa
return err return err
} }
_, err = file.Write(content) _, err = io.Copy(file, response.Body)
if err != nil { if err != nil {
_ = file.Close() _ = file.Close()
return err return err

View File

@@ -1,9 +1,11 @@
package dns package dns
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"testing" "testing"
@@ -12,14 +14,15 @@ import (
"github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/os/mock_os" "github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging" "github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/network/mock_network"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func Test_downloadAndSave(t *testing.T) { func Test_downloadAndSave(t *testing.T) {
t.Parallel() t.Parallel()
const defaultURL = "https://test.com"
tests := map[string]struct { tests := map[string]struct {
url string // to trigger a new request error
content []byte content []byte
status int status int
clientErr error clientErr error
@@ -30,37 +33,39 @@ func Test_downloadAndSave(t *testing.T) {
err error err error
}{ }{
"no data": { "no data": {
url: defaultURL,
status: http.StatusOK, status: http.StatusOK,
}, },
"bad status": { "bad status": {
url: defaultURL,
status: http.StatusBadRequest, status: http.StatusBadRequest,
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"), //nolint:lll err: fmt.Errorf("bad HTTP status from %s: Bad Request", defaultURL),
}, },
"client error": { "client error": {
url: defaultURL,
clientErr: fmt.Errorf("error"), clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("Get %q: error", defaultURL),
}, },
"open error": { "open error": {
url: defaultURL,
status: http.StatusOK, status: http.StatusOK,
openErr: fmt.Errorf("error"), openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"chown error": { "chown error": {
url: defaultURL,
status: http.StatusOK, status: http.StatusOK,
chownErr: fmt.Errorf("error"), chownErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"close error": { "close error": {
url: defaultURL,
status: http.StatusOK, status: http.StatusOK,
closeErr: fmt.Errorf("error"), closeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"data": { "data": {
url: defaultURL,
content: []byte("content"), content: []byte("content"),
status: http.StatusOK, status: http.StatusOK,
}, },
@@ -73,30 +78,46 @@ func Test_downloadAndSave(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL)) logger.EXPECT().Info("downloading %s from %s", "root hints", tc.url)
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)). client := &http.Client{
Return(tc.content, tc.status, tc.clientErr) Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, tc.url, r.URL.String())
if tc.clientErr != nil {
return nil, tc.clientErr
}
return &http.Response{
StatusCode: tc.status,
Status: http.StatusText(tc.status),
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
}, nil
}),
}
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) { openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
return nil, nil return nil, nil
} }
const filepath = "/test"
if tc.clientErr == nil && tc.status == http.StatusOK { if tc.clientErr == nil && tc.status == http.StatusOK {
file := mock_os.NewMockFile(mockCtrl) file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil { if tc.openErr == nil {
writeCall := file.EXPECT().Write(tc.content). if len(tc.content) > 0 {
Return(0, tc.writeErr) file.EXPECT().
if tc.writeErr != nil { Write(tc.content).
file.EXPECT().Close().Return(tc.closeErr).After(writeCall) Return(len(tc.content), tc.writeErr)
} else {
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
} }
file.EXPECT().
Close().
Return(tc.closeErr)
file.EXPECT().
Chown(1000, 1000).
Return(tc.chownErr)
} }
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) { openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.RootHints), name) assert.Equal(t, filepath, name)
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag) assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
assert.Equal(t, os.FileMode(0400), perm) assert.Equal(t, os.FileMode(0400), perm)
return file, tc.openErr return file, tc.openErr
@@ -110,7 +131,7 @@ func Test_downloadAndSave(t *testing.T) {
} }
err := c.downloadAndSave(ctx, "root hints", err := c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints), tc.url, filepath,
1000, 1000) 1000, 1000)
if tc.err != nil { if tc.err != nil {
@@ -130,9 +151,13 @@ func Test_DownloadRootHints(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL)) logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)). client := &http.Client{
Return(nil, http.StatusOK, errors.New("test")) Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, string(constants.NamedRootURL), r.URL.String())
return nil, errors.New("test")
}),
}
c := &configurator{ c := &configurator{
logger: logger, logger: logger,
@@ -141,7 +166,7 @@ func Test_DownloadRootHints(t *testing.T) {
err := c.DownloadRootHints(ctx, 1000, 1000) err := c.DownloadRootHints(ctx, 1000, 1000)
require.Error(t, err) require.Error(t, err)
assert.Equal(t, "test", err.Error()) assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated": test`, err.Error())
} }
func Test_DownloadRootKey(t *testing.T) { func Test_DownloadRootKey(t *testing.T) {
@@ -151,9 +176,13 @@ func Test_DownloadRootKey(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL)) logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)). client := &http.Client{
Return(nil, http.StatusOK, errors.New("test")) Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, string(constants.RootKeyURL), r.URL.String())
return nil, errors.New("test")
}),
}
c := &configurator{ c := &configurator{
logger: logger, logger: logger,
@@ -162,5 +191,5 @@ func Test_DownloadRootKey(t *testing.T) {
err := c.DownloadRootKey(ctx, 1000, 1000) err := c.DownloadRootKey(ctx, 1000, 1000)
require.Error(t, err) require.Error(t, err)
assert.Equal(t, "test", err.Error()) assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated": test`, err.Error())
} }

View File

@@ -0,0 +1,9 @@
package dns
import "net/http"
type roundTripFunc func(r *http.Request) (*http.Response, error)
func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return s(r)
}

View File

@@ -0,0 +1,8 @@
package publicip
import "errors"
var (
ErrBadStatusCode = errors.New("bad HTTP status")
ErrCannotReadBody = errors.New("cannot read response body")
)

View File

@@ -3,6 +3,7 @@ package publicip
import ( import (
"context" "context"
"net" "net"
"net/http"
"sync" "sync"
"time" "time"
@@ -11,7 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
) )
type Looper interface { type Looper interface {
@@ -45,7 +45,7 @@ type looper struct {
timeSince func(time.Time) time.Duration timeSince func(time.Time) time.Duration
} }
func NewLooper(client network.Client, logger logging.Logger, func NewLooper(client *http.Client, logger logging.Logger,
settings settings.PublicIP, uid, gid int, settings settings.PublicIP, uid, gid int,
os os.OS) Looper { os os.OS) Looper {
return &looper{ return &looper{

View File

@@ -3,12 +3,11 @@ package publicip
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"github.com/qdm12/golibs/network"
) )
type IPGetter interface { type IPGetter interface {
@@ -16,11 +15,11 @@ type IPGetter interface {
} }
type ipGetter struct { type ipGetter struct {
client network.Client client *http.Client
randIntn func(n int) int randIntn func(n int) int
} }
func NewIPGetter(client network.Client) IPGetter { func NewIPGetter(client *http.Client) IPGetter {
return &ipGetter{ return &ipGetter{
client: client, client: client,
randIntn: rand.Intn, randIntn: rand.Intn,
@@ -39,12 +38,27 @@ func (i *ipGetter) Get(ctx context.Context) (ip net.IP, err error) {
"https://ipinfo.io/ip", "https://ipinfo.io/ip",
} }
url := urls[i.randIntn(len(urls))] url := urls[i.randIntn(len(urls))]
content, status, err := i.client.Get(ctx, url, network.UseRandomUserAgent())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != http.StatusOK {
return nil, fmt.Errorf("received unexpected status code %d from %s", status, url)
} }
response, err := i.client.Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
}
content, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
}
s := strings.ReplaceAll(string(content), "\n", "") s := strings.ReplaceAll(string(content), "\n", "")
ip = net.ParseIP(s) ip = net.ParseIP(s)
if ip == nil { if ip == nil {