Code maintenance: Using qdm/dns and qdm12/updated

This commit is contained in:
Quentin McGaw
2021-01-02 18:31:39 +00:00
parent 5dcbe79fa8
commit a67efd1ad1
22 changed files with 154 additions and 1855 deletions

View File

@@ -1,98 +0,0 @@
package constants
import (
"net"
"github.com/qdm12/gluetun/internal/models"
)
const (
// Cloudflare is a DNS over TLS provider.
Cloudflare models.DNSProvider = "cloudflare"
// Google is a DNS over TLS provider.
Google models.DNSProvider = "google"
// Quad9 is a DNS over TLS provider.
Quad9 models.DNSProvider = "quad9"
// Quadrant is a DNS over TLS provider.
Quadrant models.DNSProvider = "quadrant"
// CleanBrowsing is a DNS over TLS provider.
CleanBrowsing models.DNSProvider = "cleanbrowsing"
)
// DNSProviderMapping returns a constant mapping of dns provider name
// to their data such as IP addresses or TLS host name.
func DNSProviderMapping() map[models.DNSProvider]models.DNSProviderData {
return map[models.DNSProvider]models.DNSProviderData{
Cloudflare: {
IPs: []net.IP{
{1, 1, 1, 1},
{1, 0, 0, 1},
{0x26, 0x6, 0x47, 0x0, 0x47, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x11, 0x11},
{0x26, 0x6, 0x47, 0x0, 0x47, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x01},
},
SupportsTLS: true,
SupportsIPv6: true,
Host: models.DNSHost("cloudflare-dns.com"),
},
Google: {
IPs: []net.IP{
{8, 8, 8, 8},
{8, 8, 4, 4},
{0x20, 0x1, 0x48, 0x60, 0x48, 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x88, 0x88},
{0x20, 0x1, 0x48, 0x60, 0x48, 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x88, 0x44},
},
SupportsTLS: true,
SupportsIPv6: true,
Host: models.DNSHost("dns.google"),
},
Quad9: {
IPs: []net.IP{
{9, 9, 9, 9},
{149, 112, 112, 112},
{0x26, 0x20, 0x0, 0xfe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xfe},
{0x26, 0x20, 0x0, 0xfe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
},
SupportsTLS: true,
SupportsIPv6: true,
Host: models.DNSHost("dns.quad9.net"),
},
Quadrant: {
IPs: []net.IP{
{12, 159, 2, 159},
{0x20, 0x1, 0x18, 0x90, 0x14, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x59},
},
SupportsTLS: true,
SupportsIPv6: true,
Host: models.DNSHost("dns-tls.qis.io"),
},
CleanBrowsing: {
IPs: []net.IP{
{185, 228, 168, 9},
{185, 228, 169, 9},
{0x2a, 0xd, 0x2a, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2},
{0x2a, 0xd, 0x2a, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2},
},
SupportsTLS: true,
SupportsIPv6: true,
Host: models.DNSHost("security-filter-dns.cleanbrowsing.org"),
},
}
}
// Block lists URLs.
//nolint:lll
const (
AdsBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated"
AdsBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated"
MaliciousBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated"
MaliciousBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated"
SurveillanceBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated"
SurveillanceBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated"
)
// DNS certificates to fetch.
// TODO obtain from source directly, see qdm12/updated).
const (
NamedRootURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"
RootKeyURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"
)

View File

@@ -1,43 +0,0 @@
package dns
import (
"context"
"fmt"
"io"
"strings"
"github.com/qdm12/gluetun/internal/constants"
)
func (c *configurator) Start(ctx context.Context, verbosityDetailsLevel uint8) (
stdout io.ReadCloser, waitFn func() error, err error) {
c.logger.Info("starting unbound")
args := []string{"-d", "-c", string(constants.UnboundConf)}
if verbosityDetailsLevel > 0 {
args = append(args, "-"+strings.Repeat("v", int(verbosityDetailsLevel)))
}
// Only logs to stderr
_, stdout, waitFn, err = c.commander.Start(ctx, "unbound", args...)
return stdout, waitFn, err
}
func (c *configurator) Version(ctx context.Context) (version string, err error) {
output, err := c.commander.Run(ctx, "unbound", "-V")
if err != nil {
return "", fmt.Errorf("unbound version: %w", err)
}
for _, line := range strings.Split(output, "\n") {
if strings.Contains(line, "Version ") {
words := strings.Fields(line)
const minWords = 2
if len(words) < minWords {
continue
}
version = words[1]
}
}
if version == "" {
return "", fmt.Errorf("unbound version was not found in %q", output)
}
return version, nil
}

View File

@@ -1,70 +0,0 @@
package dns
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/command/mock_command"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Start(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("starting unbound")
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Start(context.Background(), "unbound", "-d", "-c", string(constants.UnboundConf), "-vv").
Return(nil, nil, nil, nil)
c := &configurator{commander: commander, logger: logger}
stdout, waitFn, err := c.Start(context.Background(), 2)
assert.Nil(t, stdout)
assert.Nil(t, waitFn)
assert.NoError(t, err)
}
func Test_Version(t *testing.T) {
t.Parallel()
tests := map[string]struct {
runOutput string
runErr error
version string
err error
}{
"no data": {
err: fmt.Errorf(`unbound version was not found in ""`),
},
"2 lines with version": {
runOutput: "Version \nVersion 1.0-a hello\n",
version: "1.0-a",
},
"run error": {
runErr: fmt.Errorf("error"),
err: fmt.Errorf("unbound version: error"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Run(context.Background(), "unbound", "-V").
Return(tc.runOutput, tc.runErr)
c := &configurator{commander: commander}
version, err := c.Version(context.Background())
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.version, version)
})
}
}

View File

@@ -1,325 +0,0 @@
package dns
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS,
username string, puid, pgid int) (err error) {
c.logger.Info("generating Unbound configuration")
lines, warnings := generateUnboundConf(ctx, settings, username, c.client, c.logger)
for _, warning := range warnings {
c.logger.Warn(warning)
}
const filepath = string(constants.UnboundConf)
file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil {
return err
}
_, err = file.WriteString(strings.Join(lines, "\n"))
if err != nil {
_ = file.Close()
return err
}
if err := file.Chown(puid, pgid); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
}
// MakeUnboundConf generates an Unbound configuration from the user provided settings.
func generateUnboundConf(ctx context.Context, settings settings.DNS, username string,
client *http.Client, logger logging.Logger) (
lines []string, warnings []error) {
doIPv6 := "no"
if settings.IPv6 {
doIPv6 = "yes"
}
serverSection := map[string]string{
// Logging
"verbosity": fmt.Sprintf("%d", settings.VerbosityLevel),
"val-log-level": fmt.Sprintf("%d", settings.ValidationLogLevel),
"use-syslog": "no",
// Performance
"num-threads": "1",
"prefetch": "yes",
"prefetch-key": "yes",
"key-cache-size": "16m",
"key-cache-slabs": "4",
"msg-cache-size": "4m",
"msg-cache-slabs": "4",
"rrset-cache-size": "4m",
"rrset-cache-slabs": "4",
"cache-min-ttl": "3600",
"cache-max-ttl": "9000",
// Privacy
"rrset-roundrobin": "yes",
"hide-identity": "yes",
"hide-version": "yes",
// Security
"tls-cert-bundle": fmt.Sprintf("%q", constants.CACertificates),
"root-hints": fmt.Sprintf("%q", constants.RootHints),
"trust-anchor-file": fmt.Sprintf("%q", constants.RootKey),
"harden-below-nxdomain": "yes",
"harden-referral-path": "yes",
"harden-algo-downgrade": "yes",
// Network
"do-ip4": "yes",
"do-ip6": doIPv6,
"interface": "0.0.0.0",
"port": "53",
// Other
"username": fmt.Sprintf("%q", username),
}
// Block lists
hostnamesLines, ipsLines, warnings := buildBlocked(ctx, client,
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
settings.AllowedHostnames, settings.PrivateAddresses,
)
logger.Info("%d hostnames blocked overall", len(hostnamesLines))
logger.Info("%d IP addresses blocked overall", len(ipsLines))
sort.Slice(hostnamesLines, func(i, j int) bool { // for unit tests really
return hostnamesLines[i] < hostnamesLines[j]
})
sort.Slice(ipsLines, func(i, j int) bool { // for unit tests really
return ipsLines[i] < ipsLines[j]
})
// Server
lines = append(lines, "server:")
serverLines := make([]string, len(serverSection))
i := 0
for k, v := range serverSection {
serverLines[i] = " " + k + ": " + v
i++
}
sort.Slice(serverLines, func(i, j int) bool {
return serverLines[i] < serverLines[j]
})
lines = append(lines, serverLines...)
lines = append(lines, hostnamesLines...)
lines = append(lines, ipsLines...)
// Forward zone
lines = append(lines, "forward-zone:")
forwardZoneSection := map[string]string{
"name": "\".\"",
"forward-tls-upstream": "yes",
}
if settings.Caching {
forwardZoneSection["forward-no-cache"] = "no"
} else {
forwardZoneSection["forward-no-cache"] = "yes"
}
forwardZoneLines := make([]string, len(forwardZoneSection))
i = 0
for k, v := range forwardZoneSection {
forwardZoneLines[i] = " " + k + ": " + v
i++
}
sort.Slice(forwardZoneLines, func(i, j int) bool {
return forwardZoneLines[i] < forwardZoneLines[j]
})
for _, provider := range settings.Providers {
providerData := constants.DNSProviderMapping()[provider]
for _, IP := range providerData.IPs {
forwardZoneLines = append(forwardZoneLines,
fmt.Sprintf(" forward-addr: %s@853#%s", IP, providerData.Host))
}
}
lines = append(lines, forwardZoneLines...)
return lines, warnings
}
func buildBlocked(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
chHostnames := make(chan []string)
chIPs := make(chan []string)
chErrors := make(chan []error)
go func() {
lines, errs := buildBlockedHostnames(ctx, client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
chHostnames <- lines
chErrors <- errs
}()
go func() {
lines, errs := buildBlockedIPs(ctx, client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
chIPs <- lines
chErrors <- errs
}()
n := 2
for n > 0 {
select {
case lines := <-chHostnames:
hostnamesLines = append(hostnamesLines, lines...)
case lines := <-chIPs:
ipsLines = append(ipsLines, lines...)
case routineErrs := <-chErrors:
errs = append(errs, routineErrs...)
n--
}
}
return hostnamesLines, ipsLines, errs
}
func getList(ctx context.Context, client *http.Client, url string) (results []string, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
response, err := client.Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
}
content, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
}
results = strings.Split(string(content), "\n")
// remove empty lines
last := len(results) - 1
for i := range results {
if len(results[i]) == 0 {
results[i] = results[last]
last--
}
}
results = results[:last+1]
if len(results) == 0 {
return nil, nil
}
return results, nil
}
func buildBlockedHostnames(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames []string) (lines []string, errs []error) {
chResults := make(chan []string)
chError := make(chan error)
listsLeftToFetch := 0
if blockMalicious {
listsLeftToFetch++
go func() {
results, err := getList(ctx, client, string(constants.MaliciousBlockListHostnamesURL))
chResults <- results
chError <- err
}()
}
if blockAds {
listsLeftToFetch++
go func() {
results, err := getList(ctx, client, string(constants.AdsBlockListHostnamesURL))
chResults <- results
chError <- err
}()
}
if blockSurveillance {
listsLeftToFetch++
go func() {
results, err := getList(ctx, client, string(constants.SurveillanceBlockListHostnamesURL))
chResults <- results
chError <- err
}()
}
uniqueResults := make(map[string]struct{})
for listsLeftToFetch > 0 {
select {
case results := <-chResults:
for _, result := range results {
uniqueResults[result] = struct{}{}
}
case err := <-chError:
listsLeftToFetch--
if err != nil {
errs = append(errs, err)
}
}
}
for _, allowedHostname := range allowedHostnames {
delete(uniqueResults, allowedHostname)
}
for result := range uniqueResults {
lines = append(lines, " local-zone: \""+result+"\" static")
}
return lines, errs
}
func buildBlockedIPs(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
privateAddresses []string) (lines []string, errs []error) {
chResults := make(chan []string)
chError := make(chan error)
listsLeftToFetch := 0
if blockMalicious {
listsLeftToFetch++
go func() {
results, err := getList(ctx, client, string(constants.MaliciousBlockListIPsURL))
chResults <- results
chError <- err
}()
}
if blockAds {
listsLeftToFetch++
go func() {
results, err := getList(ctx, client, string(constants.AdsBlockListIPsURL))
chResults <- results
chError <- err
}()
}
if blockSurveillance {
listsLeftToFetch++
go func() {
results, err := getList(ctx, client, string(constants.SurveillanceBlockListIPsURL))
chResults <- results
chError <- err
}()
}
uniqueResults := make(map[string]struct{})
for listsLeftToFetch > 0 {
select {
case results := <-chResults:
for _, result := range results {
uniqueResults[result] = struct{}{}
}
case err := <-chError:
listsLeftToFetch--
if err != nil {
errs = append(errs, err)
}
}
}
for _, privateAddress := range privateAddresses {
uniqueResults[privateAddress] = struct{}{}
}
for result := range uniqueResults {
lines = append(lines, " private-address: "+result)
}
return lines, errs
}

View File

@@ -1,702 +0,0 @@
package dns
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_generateUnboundConf(t *testing.T) {
t.Parallel()
settings := settings.DNS{
Providers: []models.DNSProvider{constants.Cloudflare, constants.Quad9},
AllowedHostnames: []string{"a"},
PrivateAddresses: []string{"9.9.9.9"},
BlockMalicious: true,
BlockSurveillance: false,
BlockAds: false,
VerbosityLevel: 2,
ValidationLogLevel: 3,
Caching: true,
IPv6: true,
}
mockCtrl := gomock.NewController(t)
ctx := context.Background()
clientCalls := map[models.URL]int{
constants.MaliciousBlockListIPsURL: 0,
constants.MaliciousBlockListHostnamesURL: 0,
}
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body string
switch url {
case constants.MaliciousBlockListIPsURL:
body = "c\nd"
case constants.MaliciousBlockListHostnamesURL:
body = "b\na\nc"
default:
t.Errorf("unknown URL %q", url)
return nil, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader(body)),
}, nil
}),
}
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("%d hostnames blocked overall", 2)
logger.EXPECT().Info("%d IP addresses blocked overall", 3)
lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
require.Len(t, warnings, 0)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
const expected = `
server:
cache-max-ttl: 9000
cache-min-ttl: 3600
do-ip4: yes
do-ip6: yes
harden-algo-downgrade: yes
harden-below-nxdomain: yes
harden-referral-path: yes
hide-identity: yes
hide-version: yes
interface: 0.0.0.0
key-cache-size: 16m
key-cache-slabs: 4
msg-cache-size: 4m
msg-cache-slabs: 4
num-threads: 1
port: 53
prefetch-key: yes
prefetch: yes
root-hints: "/etc/unbound/root.hints"
rrset-cache-size: 4m
rrset-cache-slabs: 4
rrset-roundrobin: yes
tls-cert-bundle: "/etc/ssl/certs/ca-certificates.crt"
trust-anchor-file: "/etc/unbound/root.key"
use-syslog: no
username: "nonrootuser"
val-log-level: 3
verbosity: 2
local-zone: "b" static
local-zone: "c" static
private-address: 9.9.9.9
private-address: c
private-address: d
forward-zone:
forward-no-cache: no
forward-tls-upstream: yes
name: "."
forward-addr: 1.1.1.1@853#cloudflare-dns.com
forward-addr: 1.0.0.1@853#cloudflare-dns.com
forward-addr: 2606:4700:4700::1111@853#cloudflare-dns.com
forward-addr: 2606:4700:4700::1001@853#cloudflare-dns.com
forward-addr: 9.9.9.9@853#dns.quad9.net
forward-addr: 149.112.112.112@853#dns.quad9.net
forward-addr: 2620:fe::fe@853#dns.quad9.net
forward-addr: 2620:fe::9@853#dns.quad9.net`
assert.Equal(t, expected, "\n"+strings.Join(lines, "\n"))
}
func Test_buildBlocked(t *testing.T) {
t.Parallel()
type blockParams struct {
blocked bool
content []byte
clientErr error
}
tests := map[string]struct {
malicious blockParams
ads blockParams
surveillance blockParams
allowedHostnames []string
privateAddresses []string
hostnamesLines []string
ipsLines []string
errsString []string
}{
"none blocked": {},
"all blocked without lists": {
malicious: blockParams{
blocked: true,
},
ads: blockParams{
blocked: true,
},
surveillance: blockParams{
blocked: true,
},
},
"all blocked with lists": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
hostnamesLines: []string{
" local-zone: \"ads\" static",
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: ads",
" private-address: malicious",
" private-address: surveillance"},
},
"all blocked with allowed hostnames": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
allowedHostnames: []string{"ads"},
hostnamesLines: []string{
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: ads",
" private-address: malicious",
" private-address: surveillance"},
},
"all blocked with private addresses": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
privateAddresses: []string{"ads", "192.100.1.5"},
hostnamesLines: []string{
" local-zone: \"ads\" static",
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: 192.100.1.5",
" private-address: ads",
" private-address: malicious",
" private-address: surveillance"},
},
"all blocked with lists and one error": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
clientErr: fmt.Errorf("ads error"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
hostnamesLines: []string{
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: malicious",
" private-address: surveillance"},
errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads error`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads error`,
},
},
"all blocked with errors": {
malicious: blockParams{
blocked: true,
clientErr: fmt.Errorf("malicious"),
},
ads: blockParams{
blocked: true,
clientErr: fmt.Errorf("ads"),
},
surveillance: blockParams{
blocked: true,
clientErr: fmt.Errorf("surveillance"),
},
errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated": malicious`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated": malicious`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance`,
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance`,
},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
clientCalls := map[models.URL]int{}
if tc.malicious.blocked {
clientCalls[constants.MaliciousBlockListIPsURL] = 0
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
}
if tc.ads.blocked {
clientCalls[constants.AdsBlockListIPsURL] = 0
clientCalls[constants.AdsBlockListHostnamesURL] = 0
}
if tc.surveillance.blocked {
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
}
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body []byte
var err error
switch url {
case constants.MaliciousBlockListIPsURL, constants.MaliciousBlockListHostnamesURL:
body = tc.malicious.content
err = tc.malicious.clientErr
case constants.AdsBlockListIPsURL, constants.AdsBlockListHostnamesURL:
body = tc.ads.content
err = tc.ads.clientErr
case constants.SurveillanceBlockListIPsURL, constants.SurveillanceBlockListHostnamesURL:
body = tc.surveillance.content
err = tc.surveillance.clientErr
default: // just in case if the test is badly written
t.Errorf("unknown URL %q", url)
return nil, nil
}
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}, nil
}),
}
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
tc.allowedHostnames, tc.privateAddresses)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
}
assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
})
}
}
func Test_getList(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
results []string
err error
}{
"no result": {
status: http.StatusOK,
},
"bad status": {
status: http.StatusInternalServerError,
err: fmt.Errorf("bad HTTP status from irrelevant_url: Internal Server Error"),
},
"network error": {
status: http.StatusOK,
clientErr: fmt.Errorf("error"),
err: fmt.Errorf(`Get "irrelevant_url": error`),
},
"results": {
content: []byte("a\nb\nc\n"),
status: http.StatusOK,
results: []string{"a", "b", "c"},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "irrelevant_url", r.URL.String())
if tc.clientErr != nil {
return nil, tc.clientErr
}
return &http.Response{
StatusCode: tc.status,
Status: http.StatusText(tc.status),
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
}, nil
}),
}
results, err := getList(ctx, client, "irrelevant_url")
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.results, results)
})
}
}
func Test_buildBlockedHostnames(t *testing.T) {
t.Parallel()
type blockParams struct {
blocked bool
content []byte
clientErr error
}
tests := map[string]struct {
malicious blockParams
ads blockParams
surveillance blockParams
allowedHostnames []string
lines []string
errsString []string
}{
"nothing blocked": {},
"only malicious blocked": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
clientErr: nil,
},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_b\" static"},
},
"all blocked with some duplicates": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
content: []byte("site_c\nsite_a"),
},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_b\" static",
" local-zone: \"site_c\" static"},
},
"all blocked with one errored": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
clientErr: fmt.Errorf("surveillance error"),
},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_b\" static",
" local-zone: \"site_c\" static"},
errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance error`,
},
},
"blocked with allowed hostnames": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_c\nsite_d"),
},
allowedHostnames: []string{"site_b", "site_c"},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_d\" static"},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
clientCalls := map[models.URL]int{}
if tc.malicious.blocked {
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
}
if tc.ads.blocked {
clientCalls[constants.AdsBlockListHostnamesURL] = 0
}
if tc.surveillance.blocked {
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
}
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body []byte
var err error
switch url {
case constants.MaliciousBlockListHostnamesURL:
body = tc.malicious.content
err = tc.malicious.clientErr
case constants.AdsBlockListHostnamesURL:
body = tc.ads.content
err = tc.ads.clientErr
case constants.SurveillanceBlockListHostnamesURL:
body = tc.surveillance.content
err = tc.surveillance.clientErr
default: // just in case if the test is badly written
t.Errorf("unknown URL %q", url)
return nil, nil
}
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}, nil
}),
}
lines, errs := buildBlockedHostnames(ctx, client,
tc.malicious.blocked, tc.ads.blocked,
tc.surveillance.blocked, tc.allowedHostnames)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
}
assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.lines, lines)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
})
}
}
func Test_buildBlockedIPs(t *testing.T) {
t.Parallel()
type blockParams struct {
blocked bool
content []byte
clientErr error
}
tests := map[string]struct {
malicious blockParams
ads blockParams
surveillance blockParams
privateAddresses []string
lines []string
errsString []string
}{
"nothing blocked": {},
"only malicious blocked": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
clientErr: nil,
},
lines: []string{
" private-address: site_a",
" private-address: site_b"},
},
"all blocked with some duplicates": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
content: []byte("site_c\nsite_a"),
},
lines: []string{
" private-address: site_a",
" private-address: site_b",
" private-address: site_c"},
},
"all blocked with one errored": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
clientErr: fmt.Errorf("surveillance error"),
},
lines: []string{
" private-address: site_a",
" private-address: site_b",
" private-address: site_c"},
errsString: []string{
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance error`,
},
},
"blocked with private addresses": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_c"),
},
privateAddresses: []string{"site_c", "site_d"},
lines: []string{
" private-address: site_a",
" private-address: site_b",
" private-address: site_c",
" private-address: site_d"},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
clientCalls := map[models.URL]int{}
if tc.malicious.blocked {
clientCalls[constants.MaliciousBlockListIPsURL] = 0
}
if tc.ads.blocked {
clientCalls[constants.AdsBlockListIPsURL] = 0
}
if tc.surveillance.blocked {
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
}
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
url := models.URL(r.URL.String())
if _, ok := clientCalls[url]; !ok {
t.Errorf("unknown URL %q", url)
return nil, nil
}
clientCalls[url]++
var body []byte
var err error
switch url {
case constants.MaliciousBlockListIPsURL:
body = tc.malicious.content
err = tc.malicious.clientErr
case constants.AdsBlockListIPsURL:
body = tc.ads.content
err = tc.ads.clientErr
case constants.SurveillanceBlockListIPsURL:
body = tc.surveillance.content
err = tc.surveillance.clientErr
default: // just in case if the test is badly written
t.Errorf("unknown URL %q", url)
return nil, nil
}
if err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}, nil
}),
}
lines, errs := buildBlockedIPs(ctx, client,
tc.malicious.blocked, tc.ads.blocked,
tc.surveillance.blocked, tc.privateAddresses)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
}
assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.lines, lines)
for url, count := range clientCalls {
assert.Equalf(t, 1, count, "for url %q", url)
}
})
}
}

View File

@@ -1,43 +0,0 @@
package dns
import (
"context"
"io"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
type Configurator interface {
DownloadRootHints(ctx context.Context, puid, pgid int) error
DownloadRootKey(ctx context.Context, puid, pgid int) error
MakeUnboundConf(ctx context.Context, settings settings.DNS, username string, puid, pgid int) (err error)
UseDNSInternally(IP net.IP)
UseDNSSystemWide(ip net.IP, keepNameserver bool) error
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
WaitForUnbound() (err error)
Version(ctx context.Context) (version string, err error)
}
type configurator struct {
logger logging.Logger
client *http.Client
openFile os.OpenFileFunc
commander command.Commander
lookupIP func(host string) ([]net.IP, error)
}
func NewConfigurator(logger logging.Logger, httpClient *http.Client,
openFile os.OpenFileFunc) Configurator {
return &configurator{
logger: logger.WithPrefix("dns configurator: "),
client: httpClient,
openFile: openFile,
commander: command.NewCommander(),
lookupIP: net.LookupIP,
}
}

View File

@@ -1,8 +0,0 @@
package dns
import "errors"
var (
ErrBadStatusCode = errors.New("bad HTTP status")
ErrCannotReadBody = errors.New("cannot read response body")
)

View File

@@ -4,9 +4,11 @@ import (
"context"
"errors"
"net"
"net/http"
"sync"
"time"
"github.com/qdm12/dns/pkg/unbound"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
@@ -25,7 +27,8 @@ type Looper interface {
type looper struct {
state state
conf Configurator
conf unbound.Configurator
client *http.Client
logger logging.Logger
streamMerger command.StreamMerger
username string
@@ -44,14 +47,16 @@ type looper struct {
const defaultBackoffTime = 10 * time.Second
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
streamMerger command.StreamMerger, username string, puid, pgid int) Looper {
func NewLooper(conf unbound.Configurator, settings settings.DNS, client *http.Client,
logger logging.Logger, streamMerger command.StreamMerger,
username string, puid, pgid int) Looper {
return &looper{
state: state{
status: constants.Stopped,
settings: settings,
},
conf: conf,
client: client,
logger: logger.WithPrefix("dns over tls: "),
username: username,
puid: puid,
@@ -170,7 +175,7 @@ func (l *looper) setupUnbound(ctx context.Context,
settings := l.GetSettings()
unboundCtx, cancel := context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel)
stream, waitFn, err := l.conf.Start(unboundCtx, settings.Unbound.VerbosityDetailsLevel)
if err != nil {
cancel()
if !previousCrashed {
@@ -187,7 +192,7 @@ func (l *looper) setupUnbound(ctx context.Context,
l.logger.Error(err)
}
if err := l.conf.WaitForUnbound(); err != nil {
if err := l.conf.WaitForUnbound(ctx); err != nil {
if !previousCrashed {
l.running <- constants.Crashed
}
@@ -229,8 +234,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
}
// Try with any IPv4 address from the providers chosen
for _, provider := range settings.Providers {
data := constants.DNSProviderMapping()[provider]
for _, provider := range settings.Unbound.Providers {
data, _ := unbound.GetProviderData(provider)
for _, targetIP = range data.IPs {
if targetIP.To4() != nil {
if fallback {
@@ -248,7 +253,7 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
}
// No IPv4 address found
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers)
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers)
}
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
@@ -310,14 +315,22 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
}
func (l *looper) updateFiles(ctx context.Context) (err error) {
if err := l.conf.DownloadRootHints(ctx, l.puid, l.pgid); err != nil {
return err
}
if err := l.conf.DownloadRootKey(ctx, l.puid, l.pgid); err != nil {
if err := l.conf.SetupFiles(ctx); err != nil {
return err
}
settings := l.GetSettings()
if err := l.conf.MakeUnboundConf(ctx, settings, l.username, l.puid, l.pgid); err != nil {
hostnameLines, ipLines, errs := l.conf.BuildBlocked(ctx, l.client,
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
settings.Unbound.BlockedHostnames, settings.Unbound.BlockedIPs,
settings.Unbound.AllowedHostnames)
for _, err := range errs {
l.logger.Warn(err)
}
if err := l.conf.MakeUnboundConf(
settings.Unbound, hostnameLines, ipLines,
l.username, l.puid, l.pgid); err != nil {
return err
}
return nil

View File

@@ -1,62 +0,0 @@
package dns
import (
"context"
"io/ioutil"
"net"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/os"
)
// UseDNSInternally is to change the Go program DNS only.
func (c *configurator) UseDNSInternally(ip net.IP) {
c.logger.Info("using DNS address %s internally", ip.String())
net.DefaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", net.JoinHostPort(ip.String(), "53"))
},
}
}
// UseDNSSystemWide changes the nameserver to use for DNS system wide.
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String())
const filepath = string(constants.ResolvConf)
file, err := c.openFile(filepath, os.O_RDWR|os.O_TRUNC, 0644)
if err != nil {
return err
}
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" {
lines = nil
}
found := false
if !keepNameserver { // default
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver " + ip.String()
found = true
}
}
}
if !found {
lines = append(lines, "nameserver "+ip.String())
}
s = strings.Join(lines, "\n") + "\n"
_, err = file.WriteString(s)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -1,106 +0,0 @@
package dns
import (
"fmt"
"io"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/mock_os"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
writtenData string
openErr error
readErr error
writeErr error
closeErr error
err error
}{
"no data": {
writtenData: "nameserver 127.0.0.1\n",
},
"open error": {
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
writtenData: "nameserver 127.0.0.1\n",
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
data: []byte("abc\ndef\n"),
writtenData: "abc\ndef\nnameserver 127.0.0.1\n",
},
"lines with nameserver": {
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "abc\nnameserver 127.0.0.1\ndef\n",
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
firstReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data)
return len(tc.data), nil
})
readErr := tc.readErr
if readErr == nil {
readErr = io.EOF
}
finalReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall)
if tc.readErr == nil {
writeCall := file.EXPECT().WriteString(tc.writtenData).
Return(0, tc.writeErr).After(finalReadCall)
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
}
}
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.ResolvConf), name)
assert.Equal(t, os.O_RDWR|os.O_TRUNC, flag)
assert.Equal(t, os.FileMode(0644), perm)
return file, tc.openErr
}
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
c := &configurator{
openFile: openFile,
logger: logger,
}
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,58 +0,0 @@
package dns
import (
"context"
"fmt"
"io"
"net/http"
"os"
"github.com/qdm12/gluetun/internal/constants"
)
func (c *configurator) DownloadRootHints(ctx context.Context, puid, pgid int) error {
return c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints), puid, pgid)
}
func (c *configurator) DownloadRootKey(ctx context.Context, puid, pgid int) error {
return c.downloadAndSave(ctx, "root key",
string(constants.RootKeyURL), string(constants.RootKey), puid, pgid)
}
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, puid, pgid int) error {
c.logger.Info("downloading %s from %s", logName, url)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
response, err := c.client.Do(req)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
}
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
if err != nil {
return err
}
_, err = io.Copy(file, response.Body)
if err != nil {
_ = file.Close()
return err
}
err = file.Chown(puid, pgid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -1,195 +0,0 @@
package dns
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/mock_os"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_downloadAndSave(t *testing.T) {
t.Parallel()
const defaultURL = "https://test.com"
tests := map[string]struct {
url string // to trigger a new request error
content []byte
status int
clientErr error
openErr error
writeErr error
chownErr error
closeErr error
err error
}{
"no data": {
url: defaultURL,
status: http.StatusOK,
},
"bad status": {
url: defaultURL,
status: http.StatusBadRequest,
err: fmt.Errorf("bad HTTP status from %s: Bad Request", defaultURL),
},
"client error": {
url: defaultURL,
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("Get %q: error", defaultURL),
},
"open error": {
url: defaultURL,
status: http.StatusOK,
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"chown error": {
url: defaultURL,
status: http.StatusOK,
chownErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"close error": {
url: defaultURL,
status: http.StatusOK,
closeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
url: defaultURL,
content: []byte("content"),
status: http.StatusOK,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root hints", tc.url)
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, tc.url, r.URL.String())
if tc.clientErr != nil {
return nil, tc.clientErr
}
return &http.Response{
StatusCode: tc.status,
Status: http.StatusText(tc.status),
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
}, nil
}),
}
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
return nil, nil
}
const filepath = "/test"
if tc.clientErr == nil && tc.status == http.StatusOK {
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
if len(tc.content) > 0 {
file.EXPECT().
Write(tc.content).
Return(len(tc.content), tc.writeErr)
}
file.EXPECT().
Close().
Return(tc.closeErr)
file.EXPECT().
Chown(1000, 1000).
Return(tc.chownErr)
}
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, filepath, name)
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
assert.Equal(t, os.FileMode(0400), perm)
return file, tc.openErr
}
}
c := &configurator{
logger: logger,
client: client,
openFile: openFile,
}
err := c.downloadAndSave(ctx, "root hints",
tc.url, filepath,
1000, 1000)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func Test_DownloadRootHints(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, string(constants.NamedRootURL), r.URL.String())
return nil, errors.New("test")
}),
}
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootHints(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated": test`, err.Error())
}
func Test_DownloadRootKey(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
client := &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, string(constants.RootKeyURL), r.URL.String())
return nil, errors.New("test")
}),
}
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootKey(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated": test`, err.Error())
}

View File

@@ -1,9 +0,0 @@
package dns
import "net/http"
type roundTripFunc func(r *http.Request) (*http.Response, error)
func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return s(r)
}

View File

@@ -1,28 +0,0 @@
package dns
import (
"fmt"
"time"
)
func (c *configurator) WaitForUnbound() (err error) {
const hostToResolve = "github.com"
waitDurations := [...]time.Duration{
300 * time.Millisecond,
100 * time.Millisecond,
300 * time.Millisecond,
500 * time.Millisecond,
time.Second,
2 * time.Second,
}
maxTries := len(waitDurations)
for i, waitDuration := range waitDurations {
time.Sleep(waitDuration)
_, err := c.lookupIP(hostToResolve)
if err == nil {
return nil
}
c.logger.Warn("could not resolve %s (try %d of %d): %s", hostToResolve, i+1, maxTries, err)
}
return fmt.Errorf("Unbound does not seem to be working after %d tries", maxTries)
}

View File

@@ -8,8 +8,6 @@ import (
type (
// VPNDevice is the device name used to tunnel using Openvpn.
VPNDevice string
// DNSProvider is a DNS over TLS server provider name.
DNSProvider string
// DNSHost is the DNS host to use for TLS validation.
DNSHost string
// URL is an HTTP(s) URL address.

View File

@@ -4,8 +4,9 @@ import "net"
// DNSProviderData contains information for a DNS provider.
type DNSProviderData struct {
IPs []net.IP
SupportsTLS bool
SupportsIPv6 bool
Host DNSHost
IPs []net.IP
SupportsTLS bool
SupportsIPv6 bool
SupportsDNSSec bool
Host DNSHost
}

View File

@@ -6,8 +6,7 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
dns "github.com/qdm12/dns/pkg/unbound"
libparams "github.com/qdm12/golibs/params"
)
@@ -19,20 +18,17 @@ func (r *reader) GetDNSOverTLS() (DNSOverTLS bool, err error) { //nolint:gocriti
// GetDNSOverTLSProviders obtains the DNS over TLS providers to use
// from the environment variable DOT_PROVIDERS.
func (r *reader) GetDNSOverTLSProviders() (providers []models.DNSProvider, err error) {
func (r *reader) GetDNSOverTLSProviders() (providers []string, err error) {
s, err := r.envParams.GetEnv("DOT_PROVIDERS", libparams.Default("cloudflare"))
if err != nil {
return nil, err
}
for _, word := range strings.Split(s, ",") {
provider := models.DNSProvider(word)
switch provider {
case constants.Cloudflare, constants.Google, constants.Quad9,
constants.Quadrant, constants.CleanBrowsing:
providers = append(providers, provider)
default:
for _, provider := range strings.Split(s, ",") {
_, ok := dns.GetProviderData(provider)
if !ok {
return nil, fmt.Errorf("DNS over TLS provider %q is not valid", provider)
}
providers = append(providers, provider)
}
return providers, nil
}

View File

@@ -17,7 +17,7 @@ type Reader interface {
// DNS over TLS getters
GetDNSOverTLS() (DNSOverTLS bool, err error)
GetDNSOverTLSProviders() (providers []models.DNSProvider, err error)
GetDNSOverTLSProviders() (providers []string, err error)
GetDNSOverTLSCaching() (caching bool, err error)
GetDNSOverTLSVerbosity() (verbosityLevel uint8, err error)
GetDNSOverTLSVerbosityDetails() (verbosityDetailsLevel uint8, err error)

View File

@@ -6,38 +6,28 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
unboundmodels "github.com/qdm12/dns/pkg/models"
unbound "github.com/qdm12/dns/pkg/unbound"
"github.com/qdm12/gluetun/internal/params"
)
// DNS contains settings to configure Unbound for DNS over TLS operation.
type DNS struct {
Enabled bool
KeepNameserver bool
Providers []models.DNSProvider
PlaintextAddress net.IP
AllowedHostnames []string
PrivateAddresses []string
Caching bool
BlockMalicious bool
BlockSurveillance bool
BlockAds bool
VerbosityLevel uint8
VerbosityDetailsLevel uint8
ValidationLogLevel uint8
IPv6 bool
UpdatePeriod time.Duration
type DNS struct { //nolint:maligned
Enabled bool
PlaintextAddress net.IP
KeepNameserver bool
BlockMalicious bool
BlockAds bool
BlockSurveillance bool
UpdatePeriod time.Duration
Unbound unboundmodels.Settings
}
func (d *DNS) String() string {
if !d.Enabled {
return fmt.Sprintf("DNS over TLS disabled, using plaintext DNS %s", d.PlaintextAddress)
}
caching, blockMalicious, blockSurveillance, blockAds, ipv6 := disabled, disabled, disabled, disabled, disabled
if d.Caching {
caching = enabled
}
blockMalicious, blockSurveillance, blockAds := disabled, disabled, disabled
if d.BlockMalicious {
blockMalicious = enabled
}
@@ -47,13 +37,6 @@ func (d *DNS) String() string {
if d.BlockAds {
blockAds = enabled
}
if d.IPv6 {
ipv6 = enabled
}
providersStr := make([]string, len(d.Providers))
for i := range d.Providers {
providersStr[i] = string(d.Providers[i])
}
update := "deactivated"
if d.UpdatePeriod > 0 {
update = fmt.Sprintf("every %s", d.UpdatePeriod)
@@ -63,20 +46,13 @@ func (d *DNS) String() string {
keepNameserver = "yes"
}
settingsList := []string{
"DNS over TLS settings:",
"DNS over TLS provider:\n |--" + strings.Join(providersStr, "\n |--"),
"Caching: " + caching,
"DNS settings:",
"Block malicious: " + blockMalicious,
"Block surveillance: " + blockSurveillance,
"Block ads: " + blockAds,
"Allowed hostnames:\n |--" + strings.Join(d.AllowedHostnames, "\n |--"),
"Private addresses:\n |--" + strings.Join(d.PrivateAddresses, "\n |--"),
"Verbosity level: " + fmt.Sprintf("%d/5", d.VerbosityLevel),
"Verbosity details level: " + fmt.Sprintf("%d/4", d.VerbosityDetailsLevel),
"Validation log level: " + fmt.Sprintf("%d/2", d.ValidationLogLevel),
"IPv6 resolution: " + ipv6,
"Update: " + update,
"Keep nameserver (disabled blocking): " + keepNameserver,
"Unbound settings: " + "\n |--" + strings.Join(d.Unbound.Lines(), "\n |--"),
}
return strings.Join(settingsList, "\n |--")
}
@@ -87,22 +63,18 @@ func GetDNSSettings(paramsReader params.Reader) (settings DNS, err error) {
if err != nil {
return settings, err
}
if !settings.Enabled {
settings.PlaintextAddress, err = paramsReader.GetDNSPlaintext()
return settings, err
}
settings.Providers, err = paramsReader.GetDNSOverTLSProviders()
// Plain DNS settings
settings.PlaintextAddress, err = paramsReader.GetDNSPlaintext()
if err != nil {
return settings, err
}
settings.AllowedHostnames, err = paramsReader.GetDNSUnblockedHostnames()
if err != nil {
return settings, err
}
settings.Caching, err = paramsReader.GetDNSOverTLSCaching()
settings.KeepNameserver, err = paramsReader.GetDNSKeepNameserver()
if err != nil {
return settings, err
}
// DNS over TLS external settings
settings.BlockMalicious, err = paramsReader.GetDNSMaliciousBlocking()
if err != nil {
return settings, err
@@ -115,50 +87,71 @@ func GetDNSSettings(paramsReader params.Reader) (settings DNS, err error) {
if err != nil {
return settings, err
}
settings.VerbosityLevel, err = paramsReader.GetDNSOverTLSVerbosity()
if err != nil {
return settings, err
}
settings.VerbosityDetailsLevel, err = paramsReader.GetDNSOverTLSVerbosityDetails()
if err != nil {
return settings, err
}
settings.ValidationLogLevel, err = paramsReader.GetDNSOverTLSValidationLogLevel()
if err != nil {
return settings, err
}
settings.PrivateAddresses, err = paramsReader.GetDNSOverTLSPrivateAddresses()
if err != nil {
return settings, err
}
settings.IPv6, err = paramsReader.GetDNSOverTLSIPv6()
if err != nil {
return settings, err
}
settings.UpdatePeriod, err = paramsReader.GetDNSUpdatePeriod()
if err != nil {
return settings, err
}
settings.KeepNameserver, err = paramsReader.GetDNSKeepNameserver()
// Unbound specific settings
settings.Unbound, err = getUnboundSettings(paramsReader)
if err != nil {
return settings, err
}
// Consistency check
IPv6Support := false
for _, provider := range settings.Providers {
providerData, ok := constants.DNSProviderMapping()[provider]
for _, provider := range settings.Unbound.Providers {
providerData, ok := unbound.GetProviderData(provider)
switch {
case !ok:
return settings, fmt.Errorf("DNS provider %q does not have associated data", provider)
case !providerData.SupportsTLS:
case providerData.SupportsTLS:
return settings, fmt.Errorf("DNS provider %q does not support DNS over TLS", provider)
case providerData.SupportsIPv6:
IPv6Support = true
}
}
if settings.IPv6 && !IPv6Support {
if settings.Unbound.IPv6 && !IPv6Support {
return settings, fmt.Errorf("None of the DNS over TLS provider(s) set support IPv6")
}
return settings, nil
}
func getUnboundSettings(reader params.Reader) (settings unboundmodels.Settings, err error) {
settings.Providers, err = reader.GetDNSOverTLSProviders()
if err != nil {
return settings, err
}
settings.ListeningPort = 53
settings.Caching, err = reader.GetDNSOverTLSCaching()
if err != nil {
return settings, err
}
settings.IPv4 = true
settings.IPv6, err = reader.GetDNSOverTLSIPv6()
if err != nil {
return settings, err
}
settings.VerbosityLevel, err = reader.GetDNSOverTLSVerbosity()
if err != nil {
return settings, err
}
settings.VerbosityDetailsLevel, err = reader.GetDNSOverTLSVerbosityDetails()
if err != nil {
return settings, err
}
settings.ValidationLogLevel, err = reader.GetDNSOverTLSValidationLogLevel()
if err != nil {
return settings, err
}
settings.BlockedHostnames = []string{}
settings.BlockedIPs, err = reader.GetDNSOverTLSPrivateAddresses()
if err != nil {
return settings, err
}
settings.AllowedHostnames, err = reader.GetDNSUnblockedHostnames()
if err != nil {
return settings, err
}
return settings, nil
}