From 8d5f2fec09242879a44977976275936b302fcba7 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 29 Dec 2020 02:55:34 +0000 Subject: [PATCH] Code maintenance: use native Go HTTP client --- cmd/gluetun/main.go | 6 +- internal/dns/conf.go | 32 +++- internal/dns/conf_test.go | 294 ++++++++++++++++++++++++++------- internal/dns/dns.go | 8 +- internal/dns/errors.go | 8 + internal/dns/roots.go | 17 +- internal/dns/roots_test.go | 87 ++++++---- internal/dns/roundtrip_test.go | 9 + internal/publicip/errors.go | 8 + internal/publicip/loop.go | 4 +- internal/publicip/publicip.go | 28 +++- 11 files changed, 379 insertions(+), 122 deletions(-) create mode 100644 internal/dns/errors.go create mode 100644 internal/dns/roundtrip_test.go create mode 100644 internal/publicip/errors.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 0593a6dd..f07a2e15 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -36,7 +36,6 @@ import ( versionpkg "github.com/qdm12/gluetun/internal/version" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/network" ) //nolint:gochecknoglobals @@ -89,11 +88,10 @@ func _main(background context.Context, buildInfo models.BuildInformation, const clientTimeout = 15 * time.Second httpClient := &http.Client{Timeout: clientTimeout} - client := network.NewClient(clientTimeout) // Create configurators alpineConf := alpine.NewConfigurator(os.OpenFile, osUser) ovpnConf := openvpn.NewConfigurator(logger, os, unix) - dnsConf := dns.NewConfigurator(logger, client, os.OpenFile) + dnsConf := dns.NewConfigurator(logger, httpClient, os.OpenFile) routingConf := routing.NewRouting(logger) firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile) streamMerger := command.NewStreamMerger() @@ -253,7 +251,7 @@ func _main(background context.Context, buildInfo models.BuildInformation, go unboundLooper.Run(ctx, wg, signalDNSReady) publicIPLooper := publicip.NewLooper( - client, logger, allSettings.PublicIP, uid, gid, os) + httpClient, logger, allSettings.PublicIP, uid, gid, os) wg.Add(1) go publicIPLooper.Run(ctx, wg) wg.Add(1) diff --git a/internal/dns/conf.go b/internal/dns/conf.go index 3ab1b79a..d9f51e9b 100644 --- a/internal/dns/conf.go +++ b/internal/dns/conf.go @@ -3,6 +3,7 @@ package dns import ( "context" "fmt" + "io/ioutil" "net/http" "sort" "strings" @@ -11,7 +12,6 @@ import ( "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/network" ) func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS, @@ -48,7 +48,7 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN // MakeUnboundConf generates an Unbound configuration from the user provided settings. func generateUnboundConf(ctx context.Context, settings settings.DNS, username string, - client network.Client, logger logging.Logger) ( + client *http.Client, logger logging.Logger) ( lines []string, warnings []error) { doIPv6 := "no" if settings.IPv6 { @@ -151,7 +151,7 @@ func generateUnboundConf(ctx context.Context, settings settings.DNS, username st return lines, warnings } -func buildBlocked(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool, +func buildBlocked(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool, allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) { chHostnames := make(chan []string) chIPs := make(chan []string) @@ -181,13 +181,27 @@ func buildBlocked(ctx context.Context, client network.Client, blockMalicious, bl return hostnamesLines, ipsLines, errs } -func getList(ctx context.Context, client network.Client, url string) (results []string, err error) { - content, status, err := client.Get(ctx, url) +func getList(ctx context.Context, client *http.Client, url string) (results []string, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err - } else if status != http.StatusOK { - return nil, fmt.Errorf("HTTP status code is %d and not 200", status) } + + response, err := client.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status) + } + + content, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err) + } + results = strings.Split(string(content), "\n") // remove empty lines @@ -206,7 +220,7 @@ func getList(ctx context.Context, client network.Client, url string) (results [] return results, nil } -func buildBlockedHostnames(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool, +func buildBlockedHostnames(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool, allowedHostnames []string) (lines []string, errs []error) { chResults := make(chan []string) chError := make(chan error) @@ -258,7 +272,7 @@ func buildBlockedHostnames(ctx context.Context, client network.Client, blockMali return lines, errs } -func buildBlockedIPs(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool, +func buildBlockedIPs(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool, privateAddresses []string) (lines []string, errs []error) { chResults := make(chan []string) chError := make(chan error) diff --git a/internal/dns/conf_test.go b/internal/dns/conf_test.go index 8d287078..cec13466 100644 --- a/internal/dns/conf_test.go +++ b/internal/dns/conf_test.go @@ -1,8 +1,11 @@ package dns import ( + "bytes" "context" "fmt" + "io/ioutil" + "net/http" "strings" "testing" @@ -11,7 +14,6 @@ import ( "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/logging/mock_logging" - "github.com/qdm12/golibs/network/mock_network" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,17 +34,45 @@ func Test_generateUnboundConf(t *testing.T) { } mockCtrl := gomock.NewController(t) ctx := context.Background() - client := mock_network.NewMockClient(mockCtrl) - client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)). - Return([]byte("b\na\nc"), 200, nil) - client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)). - Return([]byte("c\nd\n"), 200, nil) + + clientCalls := map[models.URL]int{ + constants.MaliciousBlockListIPsURL: 0, + constants.MaliciousBlockListHostnamesURL: 0, + } + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + url := models.URL(r.URL.String()) + if _, ok := clientCalls[url]; !ok { + t.Errorf("unknown URL %q", url) + return nil, nil + } + clientCalls[url]++ + var body string + switch url { + case constants.MaliciousBlockListIPsURL: + body = "c\nd" + case constants.MaliciousBlockListHostnamesURL: + body = "b\na\nc" + default: + t.Errorf("unknown URL %q", url) + return nil, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(body)), + }, nil + }), + } + logger := mock_logging.NewMockLogger(mockCtrl) logger.EXPECT().Info("%d hostnames blocked overall", 2) logger.EXPECT().Info("%d IP addresses blocked overall", 3) lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger) require.Len(t, warnings, 0) - expected := ` + for url, count := range clientCalls { + assert.Equalf(t, 1, count, "for url %q", url) + } + const expected = ` server: cache-max-ttl: 9000 cache-min-ttl: 3600 @@ -209,7 +239,10 @@ func Test_buildBlocked(t *testing.T) { ipsLines: []string{ " private-address: malicious", " private-address: surveillance"}, - errsString: []string{"ads error", "ads error"}, + errsString: []string{ + `Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads error`, + `Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads error`, + }, }, "all blocked with errors": { malicious: blockParams{ @@ -224,37 +257,74 @@ func Test_buildBlocked(t *testing.T) { blocked: true, clientErr: fmt.Errorf("surveillance"), }, - errsString: []string{"malicious", "malicious", "ads", "ads", "surveillance", "surveillance"}, + errsString: []string{ + `Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated": malicious`, + `Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated": malicious`, + `Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads`, + `Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads`, + `Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance`, + `Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance`, + }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() - mockCtrl := gomock.NewController(t) ctx := context.Background() - client := mock_network.NewMockClient(mockCtrl) + + clientCalls := map[models.URL]int{} if tc.malicious.blocked { - client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)). - Return(tc.malicious.content, 200, tc.malicious.clientErr) - client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)). - Return(tc.malicious.content, 200, tc.malicious.clientErr) + clientCalls[constants.MaliciousBlockListIPsURL] = 0 + clientCalls[constants.MaliciousBlockListHostnamesURL] = 0 } if tc.ads.blocked { - client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)). - Return(tc.ads.content, 200, tc.ads.clientErr) - client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)). - Return(tc.ads.content, 200, tc.ads.clientErr) + clientCalls[constants.AdsBlockListIPsURL] = 0 + clientCalls[constants.AdsBlockListHostnamesURL] = 0 } if tc.surveillance.blocked { - client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)). - Return(tc.surveillance.content, 200, tc.surveillance.clientErr) - client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)). - Return(tc.surveillance.content, 200, tc.surveillance.clientErr) + clientCalls[constants.SurveillanceBlockListIPsURL] = 0 + clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0 } + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + url := models.URL(r.URL.String()) + if _, ok := clientCalls[url]; !ok { + t.Errorf("unknown URL %q", url) + return nil, nil + } + clientCalls[url]++ + var body []byte + var err error + switch url { + case constants.MaliciousBlockListIPsURL, constants.MaliciousBlockListHostnamesURL: + body = tc.malicious.content + err = tc.malicious.clientErr + case constants.AdsBlockListIPsURL, constants.AdsBlockListHostnamesURL: + body = tc.ads.content + err = tc.ads.clientErr + case constants.SurveillanceBlockListIPsURL, constants.SurveillanceBlockListHostnamesURL: + body = tc.surveillance.content + err = tc.surveillance.clientErr + default: // just in case if the test is badly written + t.Errorf("unknown URL %q", url) + return nil, nil + } + if err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(body)), + }, nil + }), + } + hostnamesLines, ipsLines, errs := buildBlocked(ctx, client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.allowedHostnames, tc.privateAddresses) + var errsString []string for _, err := range errs { errsString = append(errsString, err.Error()) @@ -262,6 +332,10 @@ func Test_buildBlocked(t *testing.T) { assert.ElementsMatch(t, tc.errsString, errsString) assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines) assert.ElementsMatch(t, tc.ipsLines, ipsLines) + + for url, count := range clientCalls { + assert.Equalf(t, 1, count, "for url %q", url) + } }) } } @@ -275,20 +349,45 @@ func Test_getList(t *testing.T) { results []string err error }{ - "no result": {nil, 200, nil, nil, nil}, - "bad status": {nil, 500, nil, nil, fmt.Errorf("HTTP status code is 500 and not 200")}, - "network error": {nil, 200, fmt.Errorf("error"), nil, fmt.Errorf("error")}, - "results": {[]byte("a\nb\nc\n"), 200, nil, []string{"a", "b", "c"}, nil}, + "no result": { + status: http.StatusOK, + }, + "bad status": { + status: http.StatusInternalServerError, + err: fmt.Errorf("bad HTTP status from irrelevant_url: Internal Server Error"), + }, + "network error": { + status: http.StatusOK, + clientErr: fmt.Errorf("error"), + err: fmt.Errorf(`Get "irrelevant_url": error`), + }, + "results": { + content: []byte("a\nb\nc\n"), + status: http.StatusOK, + results: []string{"a", "b", "c"}, + }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() - mockCtrl := gomock.NewController(t) + ctx := context.Background() - client := mock_network.NewMockClient(mockCtrl) - client.EXPECT().Get(ctx, "irrelevant_url"). - Return(tc.content, tc.status, tc.clientErr) + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "irrelevant_url", r.URL.String()) + if tc.clientErr != nil { + return nil, tc.clientErr + } + return &http.Response{ + StatusCode: tc.status, + Status: http.StatusText(tc.status), + Body: ioutil.NopCloser(bytes.NewReader(tc.content)), + }, nil + }), + } + results, err := getList(ctx, client, "irrelevant_url") if tc.err != nil { require.Error(t, err) @@ -316,10 +415,7 @@ func Test_buildBlockedHostnames(t *testing.T) { lines []string errsString []string }{ - "nothing blocked": { - lines: nil, - errsString: nil, - }, + "nothing blocked": {}, "only malicious blocked": { malicious: blockParams{ blocked: true, @@ -329,7 +425,6 @@ func Test_buildBlockedHostnames(t *testing.T) { lines: []string{ " local-zone: \"site_a\" static", " local-zone: \"site_b\" static"}, - errsString: nil, }, "all blocked with some duplicates": { malicious: blockParams{ @@ -348,7 +443,6 @@ func Test_buildBlockedHostnames(t *testing.T) { " local-zone: \"site_a\" static", " local-zone: \"site_b\" static", " local-zone: \"site_c\" static"}, - errsString: nil, }, "all blocked with one errored": { malicious: blockParams{ @@ -367,7 +461,9 @@ func Test_buildBlockedHostnames(t *testing.T) { " local-zone: \"site_a\" static", " local-zone: \"site_b\" static", " local-zone: \"site_c\" static"}, - errsString: []string{"surveillance error"}, + errsString: []string{ + `Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance error`, + }, }, "blocked with allowed hostnames": { malicious: blockParams{ @@ -384,34 +480,71 @@ func Test_buildBlockedHostnames(t *testing.T) { " local-zone: \"site_d\" static"}, }, } - for name, tc := range tests { //nolint:dupl + for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() - mockCtrl := gomock.NewController(t) ctx := context.Background() - client := mock_network.NewMockClient(mockCtrl) + + clientCalls := map[models.URL]int{} if tc.malicious.blocked { - client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)). - Return(tc.malicious.content, 200, tc.malicious.clientErr) + clientCalls[constants.MaliciousBlockListHostnamesURL] = 0 } if tc.ads.blocked { - client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)). - Return(tc.ads.content, 200, tc.ads.clientErr) + clientCalls[constants.AdsBlockListHostnamesURL] = 0 } if tc.surveillance.blocked { - client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)). - Return(tc.surveillance.content, 200, tc.surveillance.clientErr) + clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0 } + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + url := models.URL(r.URL.String()) + if _, ok := clientCalls[url]; !ok { + t.Errorf("unknown URL %q", url) + return nil, nil + } + clientCalls[url]++ + var body []byte + var err error + switch url { + case constants.MaliciousBlockListHostnamesURL: + body = tc.malicious.content + err = tc.malicious.clientErr + case constants.AdsBlockListHostnamesURL: + body = tc.ads.content + err = tc.ads.clientErr + case constants.SurveillanceBlockListHostnamesURL: + body = tc.surveillance.content + err = tc.surveillance.clientErr + default: // just in case if the test is badly written + t.Errorf("unknown URL %q", url) + return nil, nil + } + if err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(body)), + }, nil + }), + } + lines, errs := buildBlockedHostnames(ctx, client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.allowedHostnames) + var errsString []string for _, err := range errs { errsString = append(errsString, err.Error()) } assert.ElementsMatch(t, tc.errsString, errsString) assert.ElementsMatch(t, tc.lines, lines) + + for url, count := range clientCalls { + assert.Equalf(t, 1, count, "for url %q", url) + } }) } } @@ -431,10 +564,7 @@ func Test_buildBlockedIPs(t *testing.T) { lines []string errsString []string }{ - "nothing blocked": { - lines: nil, - errsString: nil, - }, + "nothing blocked": {}, "only malicious blocked": { malicious: blockParams{ blocked: true, @@ -444,7 +574,6 @@ func Test_buildBlockedIPs(t *testing.T) { lines: []string{ " private-address: site_a", " private-address: site_b"}, - errsString: nil, }, "all blocked with some duplicates": { malicious: blockParams{ @@ -463,7 +592,6 @@ func Test_buildBlockedIPs(t *testing.T) { " private-address: site_a", " private-address: site_b", " private-address: site_c"}, - errsString: nil, }, "all blocked with one errored": { malicious: blockParams{ @@ -482,7 +610,9 @@ func Test_buildBlockedIPs(t *testing.T) { " private-address: site_a", " private-address: site_b", " private-address: site_c"}, - errsString: []string{"surveillance error"}, + errsString: []string{ + `Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance error`, + }, }, "blocked with private addresses": { malicious: blockParams{ @@ -501,34 +631,72 @@ func Test_buildBlockedIPs(t *testing.T) { " private-address: site_d"}, }, } - for name, tc := range tests { //nolint:dupl + for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() - mockCtrl := gomock.NewController(t) + ctx := context.Background() - client := mock_network.NewMockClient(mockCtrl) + + clientCalls := map[models.URL]int{} if tc.malicious.blocked { - client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)). - Return(tc.malicious.content, 200, tc.malicious.clientErr) + clientCalls[constants.MaliciousBlockListIPsURL] = 0 } if tc.ads.blocked { - client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)). - Return(tc.ads.content, 200, tc.ads.clientErr) + clientCalls[constants.AdsBlockListIPsURL] = 0 } if tc.surveillance.blocked { - client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)). - Return(tc.surveillance.content, 200, tc.surveillance.clientErr) + clientCalls[constants.SurveillanceBlockListIPsURL] = 0 } + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + url := models.URL(r.URL.String()) + if _, ok := clientCalls[url]; !ok { + t.Errorf("unknown URL %q", url) + return nil, nil + } + clientCalls[url]++ + var body []byte + var err error + switch url { + case constants.MaliciousBlockListIPsURL: + body = tc.malicious.content + err = tc.malicious.clientErr + case constants.AdsBlockListIPsURL: + body = tc.ads.content + err = tc.ads.clientErr + case constants.SurveillanceBlockListIPsURL: + body = tc.surveillance.content + err = tc.surveillance.clientErr + default: // just in case if the test is badly written + t.Errorf("unknown URL %q", url) + return nil, nil + } + if err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(body)), + }, nil + }), + } + lines, errs := buildBlockedIPs(ctx, client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.privateAddresses) + var errsString []string for _, err := range errs { errsString = append(errsString, err.Error()) } assert.ElementsMatch(t, tc.errsString, errsString) assert.ElementsMatch(t, tc.lines, lines) + + for url, count := range clientCalls { + assert.Equalf(t, 1, count, "for url %q", url) + } }) } } diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 6a5d522c..28f4a304 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -4,12 +4,12 @@ import ( "context" "io" "net" + "net/http" "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/network" ) type Configurator interface { @@ -25,17 +25,17 @@ type Configurator interface { type configurator struct { logger logging.Logger - client network.Client + client *http.Client openFile os.OpenFileFunc commander command.Commander lookupIP func(host string) ([]net.IP, error) } -func NewConfigurator(logger logging.Logger, client network.Client, +func NewConfigurator(logger logging.Logger, httpClient *http.Client, openFile os.OpenFileFunc) Configurator { return &configurator{ logger: logger.WithPrefix("dns configurator: "), - client: client, + client: httpClient, openFile: openFile, commander: command.NewCommander(), lookupIP: net.LookupIP, diff --git a/internal/dns/errors.go b/internal/dns/errors.go new file mode 100644 index 00000000..9a7e860a --- /dev/null +++ b/internal/dns/errors.go @@ -0,0 +1,8 @@ +package dns + +import "errors" + +var ( + ErrBadStatusCode = errors.New("bad HTTP status") + ErrCannotReadBody = errors.New("cannot read response body") +) diff --git a/internal/dns/roots.go b/internal/dns/roots.go index 103c81da..f6a950fb 100644 --- a/internal/dns/roots.go +++ b/internal/dns/roots.go @@ -3,6 +3,7 @@ package dns import ( "context" "fmt" + "io" "net/http" "os" @@ -21,11 +22,19 @@ func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error { c.logger.Info("downloading %s from %s", logName, url) - content, status, err := c.client.Get(ctx, url) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return err - } else if status != http.StatusOK { - return fmt.Errorf("HTTP status code is %d for %s", status, url) + } + + response, err := c.client.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status) } file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400) @@ -33,7 +42,7 @@ func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepa return err } - _, err = file.Write(content) + _, err = io.Copy(file, response.Body) if err != nil { _ = file.Close() return err diff --git a/internal/dns/roots_test.go b/internal/dns/roots_test.go index 67e69fca..fa741a48 100644 --- a/internal/dns/roots_test.go +++ b/internal/dns/roots_test.go @@ -1,9 +1,11 @@ package dns import ( + "bytes" "context" "errors" "fmt" + "io/ioutil" "net/http" "testing" @@ -12,14 +14,15 @@ import ( "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/os/mock_os" "github.com/qdm12/golibs/logging/mock_logging" - "github.com/qdm12/golibs/network/mock_network" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_downloadAndSave(t *testing.T) { t.Parallel() + const defaultURL = "https://test.com" tests := map[string]struct { + url string // to trigger a new request error content []byte status int clientErr error @@ -30,37 +33,39 @@ func Test_downloadAndSave(t *testing.T) { err error }{ "no data": { + url: defaultURL, status: http.StatusOK, }, "bad status": { + url: defaultURL, status: http.StatusBadRequest, - err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"), //nolint:lll + err: fmt.Errorf("bad HTTP status from %s: Bad Request", defaultURL), }, "client error": { + url: defaultURL, clientErr: fmt.Errorf("error"), - err: fmt.Errorf("error"), + err: fmt.Errorf("Get %q: error", defaultURL), }, "open error": { + url: defaultURL, status: http.StatusOK, openErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, - "write error": { - status: http.StatusOK, - writeErr: fmt.Errorf("error"), - err: fmt.Errorf("error"), - }, "chown error": { + url: defaultURL, status: http.StatusOK, chownErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, "close error": { + url: defaultURL, status: http.StatusOK, closeErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, "data": { + url: defaultURL, content: []byte("content"), status: http.StatusOK, }, @@ -73,30 +78,46 @@ func Test_downloadAndSave(t *testing.T) { ctx := context.Background() logger := mock_logging.NewMockLogger(mockCtrl) - logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL)) - client := mock_network.NewMockClient(mockCtrl) - client.EXPECT().Get(ctx, string(constants.NamedRootURL)). - Return(tc.content, tc.status, tc.clientErr) + logger.EXPECT().Info("downloading %s from %s", "root hints", tc.url) + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, tc.url, r.URL.String()) + if tc.clientErr != nil { + return nil, tc.clientErr + } + return &http.Response{ + StatusCode: tc.status, + Status: http.StatusText(tc.status), + Body: ioutil.NopCloser(bytes.NewReader(tc.content)), + }, nil + }), + } openFile := func(name string, flag int, perm os.FileMode) (os.File, error) { return nil, nil } + const filepath = "/test" + if tc.clientErr == nil && tc.status == http.StatusOK { file := mock_os.NewMockFile(mockCtrl) if tc.openErr == nil { - writeCall := file.EXPECT().Write(tc.content). - Return(0, tc.writeErr) - if tc.writeErr != nil { - file.EXPECT().Close().Return(tc.closeErr).After(writeCall) - } else { - chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall) - file.EXPECT().Close().Return(tc.closeErr).After(chownCall) + if len(tc.content) > 0 { + file.EXPECT(). + Write(tc.content). + Return(len(tc.content), tc.writeErr) } + file.EXPECT(). + Close(). + Return(tc.closeErr) + file.EXPECT(). + Chown(1000, 1000). + Return(tc.chownErr) } openFile = func(name string, flag int, perm os.FileMode) (os.File, error) { - assert.Equal(t, string(constants.RootHints), name) + assert.Equal(t, filepath, name) assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag) assert.Equal(t, os.FileMode(0400), perm) return file, tc.openErr @@ -110,7 +131,7 @@ func Test_downloadAndSave(t *testing.T) { } err := c.downloadAndSave(ctx, "root hints", - string(constants.NamedRootURL), string(constants.RootHints), + tc.url, filepath, 1000, 1000) if tc.err != nil { @@ -130,9 +151,13 @@ func Test_DownloadRootHints(t *testing.T) { ctx := context.Background() logger := mock_logging.NewMockLogger(mockCtrl) logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL)) - client := mock_network.NewMockClient(mockCtrl) - client.EXPECT().Get(ctx, string(constants.NamedRootURL)). - Return(nil, http.StatusOK, errors.New("test")) + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, string(constants.NamedRootURL), r.URL.String()) + return nil, errors.New("test") + }), + } c := &configurator{ logger: logger, @@ -141,7 +166,7 @@ func Test_DownloadRootHints(t *testing.T) { err := c.DownloadRootHints(ctx, 1000, 1000) require.Error(t, err) - assert.Equal(t, "test", err.Error()) + assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated": test`, err.Error()) } func Test_DownloadRootKey(t *testing.T) { @@ -151,9 +176,13 @@ func Test_DownloadRootKey(t *testing.T) { ctx := context.Background() logger := mock_logging.NewMockLogger(mockCtrl) logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL)) - client := mock_network.NewMockClient(mockCtrl) - client.EXPECT().Get(ctx, string(constants.RootKeyURL)). - Return(nil, http.StatusOK, errors.New("test")) + + client := &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, string(constants.RootKeyURL), r.URL.String()) + return nil, errors.New("test") + }), + } c := &configurator{ logger: logger, @@ -162,5 +191,5 @@ func Test_DownloadRootKey(t *testing.T) { err := c.DownloadRootKey(ctx, 1000, 1000) require.Error(t, err) - assert.Equal(t, "test", err.Error()) + assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated": test`, err.Error()) } diff --git a/internal/dns/roundtrip_test.go b/internal/dns/roundtrip_test.go new file mode 100644 index 00000000..eaaac8ee --- /dev/null +++ b/internal/dns/roundtrip_test.go @@ -0,0 +1,9 @@ +package dns + +import "net/http" + +type roundTripFunc func(r *http.Request) (*http.Response, error) + +func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return s(r) +} diff --git a/internal/publicip/errors.go b/internal/publicip/errors.go new file mode 100644 index 00000000..ea1ff226 --- /dev/null +++ b/internal/publicip/errors.go @@ -0,0 +1,8 @@ +package publicip + +import "errors" + +var ( + ErrBadStatusCode = errors.New("bad HTTP status") + ErrCannotReadBody = errors.New("cannot read response body") +) diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index d3a6dc11..339e2837 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -3,6 +3,7 @@ package publicip import ( "context" "net" + "net/http" "sync" "time" @@ -11,7 +12,6 @@ import ( "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/network" ) type Looper interface { @@ -45,7 +45,7 @@ type looper struct { timeSince func(time.Time) time.Duration } -func NewLooper(client network.Client, logger logging.Logger, +func NewLooper(client *http.Client, logger logging.Logger, settings settings.PublicIP, uid, gid int, os os.OS) Looper { return &looper{ diff --git a/internal/publicip/publicip.go b/internal/publicip/publicip.go index 845fedaa..9d3a9e6a 100644 --- a/internal/publicip/publicip.go +++ b/internal/publicip/publicip.go @@ -3,12 +3,11 @@ package publicip import ( "context" "fmt" + "io/ioutil" "math/rand" "net" "net/http" "strings" - - "github.com/qdm12/golibs/network" ) type IPGetter interface { @@ -16,11 +15,11 @@ type IPGetter interface { } type ipGetter struct { - client network.Client + client *http.Client randIntn func(n int) int } -func NewIPGetter(client network.Client) IPGetter { +func NewIPGetter(client *http.Client) IPGetter { return &ipGetter{ client: client, randIntn: rand.Intn, @@ -39,12 +38,27 @@ func (i *ipGetter) Get(ctx context.Context) (ip net.IP, err error) { "https://ipinfo.io/ip", } url := urls[i.randIntn(len(urls))] - content, status, err := i.client.Get(ctx, url, network.UseRandomUserAgent()) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err - } else if status != http.StatusOK { - return nil, fmt.Errorf("received unexpected status code %d from %s", status, url) } + + response, err := i.client.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status) + } + + content, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err) + } + s := strings.ReplaceAll(string(content), "\n", "") ip = net.ParseIP(s) if ip == nil {