Code maintenance: use native Go HTTP client
This commit is contained in:
@@ -36,7 +36,6 @@ import (
|
|||||||
versionpkg "github.com/qdm12/gluetun/internal/version"
|
versionpkg "github.com/qdm12/gluetun/internal/version"
|
||||||
"github.com/qdm12/golibs/command"
|
"github.com/qdm12/golibs/command"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
"github.com/qdm12/golibs/network"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//nolint:gochecknoglobals
|
//nolint:gochecknoglobals
|
||||||
@@ -89,11 +88,10 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
|
|
||||||
const clientTimeout = 15 * time.Second
|
const clientTimeout = 15 * time.Second
|
||||||
httpClient := &http.Client{Timeout: clientTimeout}
|
httpClient := &http.Client{Timeout: clientTimeout}
|
||||||
client := network.NewClient(clientTimeout)
|
|
||||||
// Create configurators
|
// Create configurators
|
||||||
alpineConf := alpine.NewConfigurator(os.OpenFile, osUser)
|
alpineConf := alpine.NewConfigurator(os.OpenFile, osUser)
|
||||||
ovpnConf := openvpn.NewConfigurator(logger, os, unix)
|
ovpnConf := openvpn.NewConfigurator(logger, os, unix)
|
||||||
dnsConf := dns.NewConfigurator(logger, client, os.OpenFile)
|
dnsConf := dns.NewConfigurator(logger, httpClient, os.OpenFile)
|
||||||
routingConf := routing.NewRouting(logger)
|
routingConf := routing.NewRouting(logger)
|
||||||
firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile)
|
firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile)
|
||||||
streamMerger := command.NewStreamMerger()
|
streamMerger := command.NewStreamMerger()
|
||||||
@@ -253,7 +251,7 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
|||||||
go unboundLooper.Run(ctx, wg, signalDNSReady)
|
go unboundLooper.Run(ctx, wg, signalDNSReady)
|
||||||
|
|
||||||
publicIPLooper := publicip.NewLooper(
|
publicIPLooper := publicip.NewLooper(
|
||||||
client, logger, allSettings.PublicIP, uid, gid, os)
|
httpClient, logger, allSettings.PublicIP, uid, gid, os)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go publicIPLooper.Run(ctx, wg)
|
go publicIPLooper.Run(ctx, wg)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,7 +12,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/os"
|
"github.com/qdm12/gluetun/internal/os"
|
||||||
"github.com/qdm12/gluetun/internal/settings"
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
"github.com/qdm12/golibs/network"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS,
|
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS,
|
||||||
@@ -48,7 +48,7 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
|
|||||||
|
|
||||||
// MakeUnboundConf generates an Unbound configuration from the user provided settings.
|
// MakeUnboundConf generates an Unbound configuration from the user provided settings.
|
||||||
func generateUnboundConf(ctx context.Context, settings settings.DNS, username string,
|
func generateUnboundConf(ctx context.Context, settings settings.DNS, username string,
|
||||||
client network.Client, logger logging.Logger) (
|
client *http.Client, logger logging.Logger) (
|
||||||
lines []string, warnings []error) {
|
lines []string, warnings []error) {
|
||||||
doIPv6 := "no"
|
doIPv6 := "no"
|
||||||
if settings.IPv6 {
|
if settings.IPv6 {
|
||||||
@@ -151,7 +151,7 @@ func generateUnboundConf(ctx context.Context, settings settings.DNS, username st
|
|||||||
return lines, warnings
|
return lines, warnings
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildBlocked(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
func buildBlocked(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||||
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
|
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
|
||||||
chHostnames := make(chan []string)
|
chHostnames := make(chan []string)
|
||||||
chIPs := make(chan []string)
|
chIPs := make(chan []string)
|
||||||
@@ -181,13 +181,27 @@ func buildBlocked(ctx context.Context, client network.Client, blockMalicious, bl
|
|||||||
return hostnamesLines, ipsLines, errs
|
return hostnamesLines, ipsLines, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
func getList(ctx context.Context, client network.Client, url string) (results []string, err error) {
|
func getList(ctx context.Context, client *http.Client, url string) (results []string, err error) {
|
||||||
content, status, err := client.Get(ctx, url)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("HTTP status code is %d and not 200", status)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
|
||||||
|
}
|
||||||
|
|
||||||
results = strings.Split(string(content), "\n")
|
results = strings.Split(string(content), "\n")
|
||||||
|
|
||||||
// remove empty lines
|
// remove empty lines
|
||||||
@@ -206,7 +220,7 @@ func getList(ctx context.Context, client network.Client, url string) (results []
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildBlockedHostnames(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
func buildBlockedHostnames(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||||
allowedHostnames []string) (lines []string, errs []error) {
|
allowedHostnames []string) (lines []string, errs []error) {
|
||||||
chResults := make(chan []string)
|
chResults := make(chan []string)
|
||||||
chError := make(chan error)
|
chError := make(chan error)
|
||||||
@@ -258,7 +272,7 @@ func buildBlockedHostnames(ctx context.Context, client network.Client, blockMali
|
|||||||
return lines, errs
|
return lines, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildBlockedIPs(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
func buildBlockedIPs(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||||
privateAddresses []string) (lines []string, errs []error) {
|
privateAddresses []string) (lines []string, errs []error) {
|
||||||
chResults := make(chan []string)
|
chResults := make(chan []string)
|
||||||
chError := make(chan error)
|
chError := make(chan error)
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -11,7 +14,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/settings"
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/golibs/logging/mock_logging"
|
"github.com/qdm12/golibs/logging/mock_logging"
|
||||||
"github.com/qdm12/golibs/network/mock_network"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -32,17 +34,45 @@ func Test_generateUnboundConf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
clientCalls := map[models.URL]int{
|
||||||
Return([]byte("b\na\nc"), 200, nil)
|
constants.MaliciousBlockListIPsURL: 0,
|
||||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
constants.MaliciousBlockListHostnamesURL: 0,
|
||||||
Return([]byte("c\nd\n"), 200, nil)
|
}
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
url := models.URL(r.URL.String())
|
||||||
|
if _, ok := clientCalls[url]; !ok {
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
clientCalls[url]++
|
||||||
|
var body string
|
||||||
|
switch url {
|
||||||
|
case constants.MaliciousBlockListIPsURL:
|
||||||
|
body = "c\nd"
|
||||||
|
case constants.MaliciousBlockListHostnamesURL:
|
||||||
|
body = "b\na\nc"
|
||||||
|
default:
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader(body)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||||
logger.EXPECT().Info("%d hostnames blocked overall", 2)
|
logger.EXPECT().Info("%d hostnames blocked overall", 2)
|
||||||
logger.EXPECT().Info("%d IP addresses blocked overall", 3)
|
logger.EXPECT().Info("%d IP addresses blocked overall", 3)
|
||||||
lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
|
lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
|
||||||
require.Len(t, warnings, 0)
|
require.Len(t, warnings, 0)
|
||||||
expected := `
|
for url, count := range clientCalls {
|
||||||
|
assert.Equalf(t, 1, count, "for url %q", url)
|
||||||
|
}
|
||||||
|
const expected = `
|
||||||
server:
|
server:
|
||||||
cache-max-ttl: 9000
|
cache-max-ttl: 9000
|
||||||
cache-min-ttl: 3600
|
cache-min-ttl: 3600
|
||||||
@@ -209,7 +239,10 @@ func Test_buildBlocked(t *testing.T) {
|
|||||||
ipsLines: []string{
|
ipsLines: []string{
|
||||||
" private-address: malicious",
|
" private-address: malicious",
|
||||||
" private-address: surveillance"},
|
" private-address: surveillance"},
|
||||||
errsString: []string{"ads error", "ads error"},
|
errsString: []string{
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads error`,
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads error`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"all blocked with errors": {
|
"all blocked with errors": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -224,37 +257,74 @@ func Test_buildBlocked(t *testing.T) {
|
|||||||
blocked: true,
|
blocked: true,
|
||||||
clientErr: fmt.Errorf("surveillance"),
|
clientErr: fmt.Errorf("surveillance"),
|
||||||
},
|
},
|
||||||
errsString: []string{"malicious", "malicious", "ads", "ads", "surveillance", "surveillance"},
|
errsString: []string{
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated": malicious`,
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated": malicious`,
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads`,
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads`,
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance`,
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
|
clientCalls := map[models.URL]int{}
|
||||||
if tc.malicious.blocked {
|
if tc.malicious.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
clientCalls[constants.MaliciousBlockListIPsURL] = 0
|
||||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
|
||||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
|
||||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
|
||||||
}
|
}
|
||||||
if tc.ads.blocked {
|
if tc.ads.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
|
clientCalls[constants.AdsBlockListIPsURL] = 0
|
||||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
clientCalls[constants.AdsBlockListHostnamesURL] = 0
|
||||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
|
|
||||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
|
||||||
}
|
}
|
||||||
if tc.surveillance.blocked {
|
if tc.surveillance.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
|
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
|
||||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
|
||||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
|
|
||||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
url := models.URL(r.URL.String())
|
||||||
|
if _, ok := clientCalls[url]; !ok {
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
clientCalls[url]++
|
||||||
|
var body []byte
|
||||||
|
var err error
|
||||||
|
switch url {
|
||||||
|
case constants.MaliciousBlockListIPsURL, constants.MaliciousBlockListHostnamesURL:
|
||||||
|
body = tc.malicious.content
|
||||||
|
err = tc.malicious.clientErr
|
||||||
|
case constants.AdsBlockListIPsURL, constants.AdsBlockListHostnamesURL:
|
||||||
|
body = tc.ads.content
|
||||||
|
err = tc.ads.clientErr
|
||||||
|
case constants.SurveillanceBlockListIPsURL, constants.SurveillanceBlockListHostnamesURL:
|
||||||
|
body = tc.surveillance.content
|
||||||
|
err = tc.surveillance.clientErr
|
||||||
|
default: // just in case if the test is badly written
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
|
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
|
||||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
||||||
tc.allowedHostnames, tc.privateAddresses)
|
tc.allowedHostnames, tc.privateAddresses)
|
||||||
|
|
||||||
var errsString []string
|
var errsString []string
|
||||||
for _, err := range errs {
|
for _, err := range errs {
|
||||||
errsString = append(errsString, err.Error())
|
errsString = append(errsString, err.Error())
|
||||||
@@ -262,6 +332,10 @@ func Test_buildBlocked(t *testing.T) {
|
|||||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||||
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
|
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
|
||||||
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
|
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
|
||||||
|
|
||||||
|
for url, count := range clientCalls {
|
||||||
|
assert.Equalf(t, 1, count, "for url %q", url)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -275,20 +349,45 @@ func Test_getList(t *testing.T) {
|
|||||||
results []string
|
results []string
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"no result": {nil, 200, nil, nil, nil},
|
"no result": {
|
||||||
"bad status": {nil, 500, nil, nil, fmt.Errorf("HTTP status code is 500 and not 200")},
|
status: http.StatusOK,
|
||||||
"network error": {nil, 200, fmt.Errorf("error"), nil, fmt.Errorf("error")},
|
},
|
||||||
"results": {[]byte("a\nb\nc\n"), 200, nil, []string{"a", "b", "c"}, nil},
|
"bad status": {
|
||||||
|
status: http.StatusInternalServerError,
|
||||||
|
err: fmt.Errorf("bad HTTP status from irrelevant_url: Internal Server Error"),
|
||||||
|
},
|
||||||
|
"network error": {
|
||||||
|
status: http.StatusOK,
|
||||||
|
clientErr: fmt.Errorf("error"),
|
||||||
|
err: fmt.Errorf(`Get "irrelevant_url": error`),
|
||||||
|
},
|
||||||
|
"results": {
|
||||||
|
content: []byte("a\nb\nc\n"),
|
||||||
|
status: http.StatusOK,
|
||||||
|
results: []string{"a", "b", "c"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
client.EXPECT().Get(ctx, "irrelevant_url").
|
client := &http.Client{
|
||||||
Return(tc.content, tc.status, tc.clientErr)
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, "irrelevant_url", r.URL.String())
|
||||||
|
if tc.clientErr != nil {
|
||||||
|
return nil, tc.clientErr
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: tc.status,
|
||||||
|
Status: http.StatusText(tc.status),
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
results, err := getList(ctx, client, "irrelevant_url")
|
results, err := getList(ctx, client, "irrelevant_url")
|
||||||
if tc.err != nil {
|
if tc.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -316,10 +415,7 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
|||||||
lines []string
|
lines []string
|
||||||
errsString []string
|
errsString []string
|
||||||
}{
|
}{
|
||||||
"nothing blocked": {
|
"nothing blocked": {},
|
||||||
lines: nil,
|
|
||||||
errsString: nil,
|
|
||||||
},
|
|
||||||
"only malicious blocked": {
|
"only malicious blocked": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
blocked: true,
|
blocked: true,
|
||||||
@@ -329,7 +425,6 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
|||||||
lines: []string{
|
lines: []string{
|
||||||
" local-zone: \"site_a\" static",
|
" local-zone: \"site_a\" static",
|
||||||
" local-zone: \"site_b\" static"},
|
" local-zone: \"site_b\" static"},
|
||||||
errsString: nil,
|
|
||||||
},
|
},
|
||||||
"all blocked with some duplicates": {
|
"all blocked with some duplicates": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -348,7 +443,6 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
|||||||
" local-zone: \"site_a\" static",
|
" local-zone: \"site_a\" static",
|
||||||
" local-zone: \"site_b\" static",
|
" local-zone: \"site_b\" static",
|
||||||
" local-zone: \"site_c\" static"},
|
" local-zone: \"site_c\" static"},
|
||||||
errsString: nil,
|
|
||||||
},
|
},
|
||||||
"all blocked with one errored": {
|
"all blocked with one errored": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -367,7 +461,9 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
|||||||
" local-zone: \"site_a\" static",
|
" local-zone: \"site_a\" static",
|
||||||
" local-zone: \"site_b\" static",
|
" local-zone: \"site_b\" static",
|
||||||
" local-zone: \"site_c\" static"},
|
" local-zone: \"site_c\" static"},
|
||||||
errsString: []string{"surveillance error"},
|
errsString: []string{
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance error`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"blocked with allowed hostnames": {
|
"blocked with allowed hostnames": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -384,34 +480,71 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
|||||||
" local-zone: \"site_d\" static"},
|
" local-zone: \"site_d\" static"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for name, tc := range tests { //nolint:dupl
|
for name, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
|
clientCalls := map[models.URL]int{}
|
||||||
if tc.malicious.blocked {
|
if tc.malicious.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
|
||||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
|
||||||
}
|
}
|
||||||
if tc.ads.blocked {
|
if tc.ads.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
|
clientCalls[constants.AdsBlockListHostnamesURL] = 0
|
||||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
|
||||||
}
|
}
|
||||||
if tc.surveillance.blocked {
|
if tc.surveillance.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
|
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
|
||||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
url := models.URL(r.URL.String())
|
||||||
|
if _, ok := clientCalls[url]; !ok {
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
clientCalls[url]++
|
||||||
|
var body []byte
|
||||||
|
var err error
|
||||||
|
switch url {
|
||||||
|
case constants.MaliciousBlockListHostnamesURL:
|
||||||
|
body = tc.malicious.content
|
||||||
|
err = tc.malicious.clientErr
|
||||||
|
case constants.AdsBlockListHostnamesURL:
|
||||||
|
body = tc.ads.content
|
||||||
|
err = tc.ads.clientErr
|
||||||
|
case constants.SurveillanceBlockListHostnamesURL:
|
||||||
|
body = tc.surveillance.content
|
||||||
|
err = tc.surveillance.clientErr
|
||||||
|
default: // just in case if the test is badly written
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
lines, errs := buildBlockedHostnames(ctx, client,
|
lines, errs := buildBlockedHostnames(ctx, client,
|
||||||
tc.malicious.blocked, tc.ads.blocked,
|
tc.malicious.blocked, tc.ads.blocked,
|
||||||
tc.surveillance.blocked, tc.allowedHostnames)
|
tc.surveillance.blocked, tc.allowedHostnames)
|
||||||
|
|
||||||
var errsString []string
|
var errsString []string
|
||||||
for _, err := range errs {
|
for _, err := range errs {
|
||||||
errsString = append(errsString, err.Error())
|
errsString = append(errsString, err.Error())
|
||||||
}
|
}
|
||||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||||
assert.ElementsMatch(t, tc.lines, lines)
|
assert.ElementsMatch(t, tc.lines, lines)
|
||||||
|
|
||||||
|
for url, count := range clientCalls {
|
||||||
|
assert.Equalf(t, 1, count, "for url %q", url)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -431,10 +564,7 @@ func Test_buildBlockedIPs(t *testing.T) {
|
|||||||
lines []string
|
lines []string
|
||||||
errsString []string
|
errsString []string
|
||||||
}{
|
}{
|
||||||
"nothing blocked": {
|
"nothing blocked": {},
|
||||||
lines: nil,
|
|
||||||
errsString: nil,
|
|
||||||
},
|
|
||||||
"only malicious blocked": {
|
"only malicious blocked": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
blocked: true,
|
blocked: true,
|
||||||
@@ -444,7 +574,6 @@ func Test_buildBlockedIPs(t *testing.T) {
|
|||||||
lines: []string{
|
lines: []string{
|
||||||
" private-address: site_a",
|
" private-address: site_a",
|
||||||
" private-address: site_b"},
|
" private-address: site_b"},
|
||||||
errsString: nil,
|
|
||||||
},
|
},
|
||||||
"all blocked with some duplicates": {
|
"all blocked with some duplicates": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -463,7 +592,6 @@ func Test_buildBlockedIPs(t *testing.T) {
|
|||||||
" private-address: site_a",
|
" private-address: site_a",
|
||||||
" private-address: site_b",
|
" private-address: site_b",
|
||||||
" private-address: site_c"},
|
" private-address: site_c"},
|
||||||
errsString: nil,
|
|
||||||
},
|
},
|
||||||
"all blocked with one errored": {
|
"all blocked with one errored": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -482,7 +610,9 @@ func Test_buildBlockedIPs(t *testing.T) {
|
|||||||
" private-address: site_a",
|
" private-address: site_a",
|
||||||
" private-address: site_b",
|
" private-address: site_b",
|
||||||
" private-address: site_c"},
|
" private-address: site_c"},
|
||||||
errsString: []string{"surveillance error"},
|
errsString: []string{
|
||||||
|
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance error`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"blocked with private addresses": {
|
"blocked with private addresses": {
|
||||||
malicious: blockParams{
|
malicious: blockParams{
|
||||||
@@ -501,34 +631,72 @@ func Test_buildBlockedIPs(t *testing.T) {
|
|||||||
" private-address: site_d"},
|
" private-address: site_d"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for name, tc := range tests { //nolint:dupl
|
for name, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
|
clientCalls := map[models.URL]int{}
|
||||||
if tc.malicious.blocked {
|
if tc.malicious.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
clientCalls[constants.MaliciousBlockListIPsURL] = 0
|
||||||
Return(tc.malicious.content, 200, tc.malicious.clientErr)
|
|
||||||
}
|
}
|
||||||
if tc.ads.blocked {
|
if tc.ads.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
|
clientCalls[constants.AdsBlockListIPsURL] = 0
|
||||||
Return(tc.ads.content, 200, tc.ads.clientErr)
|
|
||||||
}
|
}
|
||||||
if tc.surveillance.blocked {
|
if tc.surveillance.blocked {
|
||||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
|
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
|
||||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
url := models.URL(r.URL.String())
|
||||||
|
if _, ok := clientCalls[url]; !ok {
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
clientCalls[url]++
|
||||||
|
var body []byte
|
||||||
|
var err error
|
||||||
|
switch url {
|
||||||
|
case constants.MaliciousBlockListIPsURL:
|
||||||
|
body = tc.malicious.content
|
||||||
|
err = tc.malicious.clientErr
|
||||||
|
case constants.AdsBlockListIPsURL:
|
||||||
|
body = tc.ads.content
|
||||||
|
err = tc.ads.clientErr
|
||||||
|
case constants.SurveillanceBlockListIPsURL:
|
||||||
|
body = tc.surveillance.content
|
||||||
|
err = tc.surveillance.clientErr
|
||||||
|
default: // just in case if the test is badly written
|
||||||
|
t.Errorf("unknown URL %q", url)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
lines, errs := buildBlockedIPs(ctx, client,
|
lines, errs := buildBlockedIPs(ctx, client,
|
||||||
tc.malicious.blocked, tc.ads.blocked,
|
tc.malicious.blocked, tc.ads.blocked,
|
||||||
tc.surveillance.blocked, tc.privateAddresses)
|
tc.surveillance.blocked, tc.privateAddresses)
|
||||||
|
|
||||||
var errsString []string
|
var errsString []string
|
||||||
for _, err := range errs {
|
for _, err := range errs {
|
||||||
errsString = append(errsString, err.Error())
|
errsString = append(errsString, err.Error())
|
||||||
}
|
}
|
||||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||||
assert.ElementsMatch(t, tc.lines, lines)
|
assert.ElementsMatch(t, tc.lines, lines)
|
||||||
|
|
||||||
|
for url, count := range clientCalls {
|
||||||
|
assert.Equalf(t, 1, count, "for url %q", url)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/os"
|
"github.com/qdm12/gluetun/internal/os"
|
||||||
"github.com/qdm12/gluetun/internal/settings"
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/golibs/command"
|
"github.com/qdm12/golibs/command"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
"github.com/qdm12/golibs/network"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Configurator interface {
|
type Configurator interface {
|
||||||
@@ -25,17 +25,17 @@ type Configurator interface {
|
|||||||
|
|
||||||
type configurator struct {
|
type configurator struct {
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
client network.Client
|
client *http.Client
|
||||||
openFile os.OpenFileFunc
|
openFile os.OpenFileFunc
|
||||||
commander command.Commander
|
commander command.Commander
|
||||||
lookupIP func(host string) ([]net.IP, error)
|
lookupIP func(host string) ([]net.IP, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigurator(logger logging.Logger, client network.Client,
|
func NewConfigurator(logger logging.Logger, httpClient *http.Client,
|
||||||
openFile os.OpenFileFunc) Configurator {
|
openFile os.OpenFileFunc) Configurator {
|
||||||
return &configurator{
|
return &configurator{
|
||||||
logger: logger.WithPrefix("dns configurator: "),
|
logger: logger.WithPrefix("dns configurator: "),
|
||||||
client: client,
|
client: httpClient,
|
||||||
openFile: openFile,
|
openFile: openFile,
|
||||||
commander: command.NewCommander(),
|
commander: command.NewCommander(),
|
||||||
lookupIP: net.LookupIP,
|
lookupIP: net.LookupIP,
|
||||||
|
|||||||
8
internal/dns/errors.go
Normal file
8
internal/dns/errors.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadStatusCode = errors.New("bad HTTP status")
|
||||||
|
ErrCannotReadBody = errors.New("cannot read response body")
|
||||||
|
)
|
||||||
@@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -21,11 +22,19 @@ func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error
|
|||||||
|
|
||||||
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
|
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
|
||||||
c.logger.Info("downloading %s from %s", logName, url)
|
c.logger.Info("downloading %s from %s", logName, url)
|
||||||
content, status, err := c.client.Get(ctx, url)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if status != http.StatusOK {
|
}
|
||||||
return fmt.Errorf("HTTP status code is %d for %s", status, url)
|
|
||||||
|
response, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
|
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
|
||||||
@@ -33,7 +42,7 @@ func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepa
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = file.Write(content)
|
_, err = io.Copy(file, response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = file.Close()
|
_ = file.Close()
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -12,14 +14,15 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/os"
|
"github.com/qdm12/gluetun/internal/os"
|
||||||
"github.com/qdm12/gluetun/internal/os/mock_os"
|
"github.com/qdm12/gluetun/internal/os/mock_os"
|
||||||
"github.com/qdm12/golibs/logging/mock_logging"
|
"github.com/qdm12/golibs/logging/mock_logging"
|
||||||
"github.com/qdm12/golibs/network/mock_network"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_downloadAndSave(t *testing.T) {
|
func Test_downloadAndSave(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
const defaultURL = "https://test.com"
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
|
url string // to trigger a new request error
|
||||||
content []byte
|
content []byte
|
||||||
status int
|
status int
|
||||||
clientErr error
|
clientErr error
|
||||||
@@ -30,37 +33,39 @@ func Test_downloadAndSave(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"no data": {
|
"no data": {
|
||||||
|
url: defaultURL,
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
},
|
},
|
||||||
"bad status": {
|
"bad status": {
|
||||||
|
url: defaultURL,
|
||||||
status: http.StatusBadRequest,
|
status: http.StatusBadRequest,
|
||||||
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"), //nolint:lll
|
err: fmt.Errorf("bad HTTP status from %s: Bad Request", defaultURL),
|
||||||
},
|
},
|
||||||
"client error": {
|
"client error": {
|
||||||
|
url: defaultURL,
|
||||||
clientErr: fmt.Errorf("error"),
|
clientErr: fmt.Errorf("error"),
|
||||||
err: fmt.Errorf("error"),
|
err: fmt.Errorf("Get %q: error", defaultURL),
|
||||||
},
|
},
|
||||||
"open error": {
|
"open error": {
|
||||||
|
url: defaultURL,
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
openErr: fmt.Errorf("error"),
|
openErr: fmt.Errorf("error"),
|
||||||
err: fmt.Errorf("error"),
|
err: fmt.Errorf("error"),
|
||||||
},
|
},
|
||||||
"write error": {
|
|
||||||
status: http.StatusOK,
|
|
||||||
writeErr: fmt.Errorf("error"),
|
|
||||||
err: fmt.Errorf("error"),
|
|
||||||
},
|
|
||||||
"chown error": {
|
"chown error": {
|
||||||
|
url: defaultURL,
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
chownErr: fmt.Errorf("error"),
|
chownErr: fmt.Errorf("error"),
|
||||||
err: fmt.Errorf("error"),
|
err: fmt.Errorf("error"),
|
||||||
},
|
},
|
||||||
"close error": {
|
"close error": {
|
||||||
|
url: defaultURL,
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
closeErr: fmt.Errorf("error"),
|
closeErr: fmt.Errorf("error"),
|
||||||
err: fmt.Errorf("error"),
|
err: fmt.Errorf("error"),
|
||||||
},
|
},
|
||||||
"data": {
|
"data": {
|
||||||
|
url: defaultURL,
|
||||||
content: []byte("content"),
|
content: []byte("content"),
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
},
|
},
|
||||||
@@ -73,30 +78,46 @@ func Test_downloadAndSave(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
logger.EXPECT().Info("downloading %s from %s", "root hints", tc.url)
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
client := &http.Client{
|
||||||
Return(tc.content, tc.status, tc.clientErr)
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, tc.url, r.URL.String())
|
||||||
|
if tc.clientErr != nil {
|
||||||
|
return nil, tc.clientErr
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: tc.status,
|
||||||
|
Status: http.StatusText(tc.status),
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const filepath = "/test"
|
||||||
|
|
||||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||||
file := mock_os.NewMockFile(mockCtrl)
|
file := mock_os.NewMockFile(mockCtrl)
|
||||||
if tc.openErr == nil {
|
if tc.openErr == nil {
|
||||||
writeCall := file.EXPECT().Write(tc.content).
|
if len(tc.content) > 0 {
|
||||||
Return(0, tc.writeErr)
|
file.EXPECT().
|
||||||
if tc.writeErr != nil {
|
Write(tc.content).
|
||||||
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
|
Return(len(tc.content), tc.writeErr)
|
||||||
} else {
|
|
||||||
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
|
|
||||||
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
|
|
||||||
}
|
}
|
||||||
|
file.EXPECT().
|
||||||
|
Close().
|
||||||
|
Return(tc.closeErr)
|
||||||
|
file.EXPECT().
|
||||||
|
Chown(1000, 1000).
|
||||||
|
Return(tc.chownErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
|
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||||
assert.Equal(t, string(constants.RootHints), name)
|
assert.Equal(t, filepath, name)
|
||||||
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
|
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
|
||||||
assert.Equal(t, os.FileMode(0400), perm)
|
assert.Equal(t, os.FileMode(0400), perm)
|
||||||
return file, tc.openErr
|
return file, tc.openErr
|
||||||
@@ -110,7 +131,7 @@ func Test_downloadAndSave(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err := c.downloadAndSave(ctx, "root hints",
|
err := c.downloadAndSave(ctx, "root hints",
|
||||||
string(constants.NamedRootURL), string(constants.RootHints),
|
tc.url, filepath,
|
||||||
1000, 1000)
|
1000, 1000)
|
||||||
|
|
||||||
if tc.err != nil {
|
if tc.err != nil {
|
||||||
@@ -130,9 +151,13 @@ func Test_DownloadRootHints(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
client := &http.Client{
|
||||||
Return(nil, http.StatusOK, errors.New("test"))
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, string(constants.NamedRootURL), r.URL.String())
|
||||||
|
return nil, errors.New("test")
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
c := &configurator{
|
c := &configurator{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
@@ -141,7 +166,7 @@ func Test_DownloadRootHints(t *testing.T) {
|
|||||||
|
|
||||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Equal(t, "test", err.Error())
|
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated": test`, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_DownloadRootKey(t *testing.T) {
|
func Test_DownloadRootKey(t *testing.T) {
|
||||||
@@ -151,9 +176,13 @@ func Test_DownloadRootKey(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||||
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
|
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
|
||||||
client := mock_network.NewMockClient(mockCtrl)
|
|
||||||
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
|
client := &http.Client{
|
||||||
Return(nil, http.StatusOK, errors.New("test"))
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, string(constants.RootKeyURL), r.URL.String())
|
||||||
|
return nil, errors.New("test")
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
c := &configurator{
|
c := &configurator{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
@@ -162,5 +191,5 @@ func Test_DownloadRootKey(t *testing.T) {
|
|||||||
|
|
||||||
err := c.DownloadRootKey(ctx, 1000, 1000)
|
err := c.DownloadRootKey(ctx, 1000, 1000)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Equal(t, "test", err.Error())
|
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated": test`, err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
9
internal/dns/roundtrip_test.go
Normal file
9
internal/dns/roundtrip_test.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type roundTripFunc func(r *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
return s(r)
|
||||||
|
}
|
||||||
8
internal/publicip/errors.go
Normal file
8
internal/publicip/errors.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package publicip
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadStatusCode = errors.New("bad HTTP status")
|
||||||
|
ErrCannotReadBody = errors.New("cannot read response body")
|
||||||
|
)
|
||||||
@@ -3,6 +3,7 @@ package publicip
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,7 +12,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/os"
|
"github.com/qdm12/gluetun/internal/os"
|
||||||
"github.com/qdm12/gluetun/internal/settings"
|
"github.com/qdm12/gluetun/internal/settings"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
"github.com/qdm12/golibs/network"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
@@ -45,7 +45,7 @@ type looper struct {
|
|||||||
timeSince func(time.Time) time.Duration
|
timeSince func(time.Time) time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(client network.Client, logger logging.Logger,
|
func NewLooper(client *http.Client, logger logging.Logger,
|
||||||
settings settings.PublicIP, uid, gid int,
|
settings settings.PublicIP, uid, gid int,
|
||||||
os os.OS) Looper {
|
os os.OS) Looper {
|
||||||
return &looper{
|
return &looper{
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ package publicip
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/network"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPGetter interface {
|
type IPGetter interface {
|
||||||
@@ -16,11 +15,11 @@ type IPGetter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ipGetter struct {
|
type ipGetter struct {
|
||||||
client network.Client
|
client *http.Client
|
||||||
randIntn func(n int) int
|
randIntn func(n int) int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIPGetter(client network.Client) IPGetter {
|
func NewIPGetter(client *http.Client) IPGetter {
|
||||||
return &ipGetter{
|
return &ipGetter{
|
||||||
client: client,
|
client: client,
|
||||||
randIntn: rand.Intn,
|
randIntn: rand.Intn,
|
||||||
@@ -39,12 +38,27 @@ func (i *ipGetter) Get(ctx context.Context) (ip net.IP, err error) {
|
|||||||
"https://ipinfo.io/ip",
|
"https://ipinfo.io/ip",
|
||||||
}
|
}
|
||||||
url := urls[i.randIntn(len(urls))]
|
url := urls[i.randIntn(len(urls))]
|
||||||
content, status, err := i.client.Get(ctx, url, network.UseRandomUserAgent())
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("received unexpected status code %d from %s", status, url)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response, err := i.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
|
||||||
|
}
|
||||||
|
|
||||||
s := strings.ReplaceAll(string(content), "\n", "")
|
s := strings.ReplaceAll(string(content), "\n", "")
|
||||||
ip = net.ParseIP(s)
|
ip = net.ParseIP(s)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user