diff --git a/cmd/main.go b/cmd/main.go index ca361933..378cf0b8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "time" "github.com/qdm12/golibs/command" @@ -75,6 +76,8 @@ func main() { e.FatalOnError(err) if allSettings.DNS.Enabled { + initialDNSToUse := constants.DNSProviderMapping()[allSettings.DNS.Providers[0]] + dnsConf.UseDNSInternally(initialDNSToUse.IPs[0]) err = dnsConf.DownloadRootHints(uid, gid) e.FatalOnError(err) err = dnsConf.DownloadRootKey(uid, gid) @@ -84,7 +87,10 @@ func main() { stream, err := dnsConf.Start(allSettings.DNS.VerbosityDetailsLevel) e.FatalOnError(err) go streamMerger.Merge("unbound", stream) - err = dnsConf.SetLocalNameserver() + dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound + err = dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}) // use Unbound + e.FatalOnError(err) + err = dnsConf.WaitForUnbound() e.FatalOnError(err) } diff --git a/go.sum b/go.sum index 3d77f9f2..1aa2e130 100644 --- a/go.sum +++ b/go.sum @@ -44,8 +44,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/qdm12/golibs v0.0.0-20200208143139-ccf3a2b4be96 h1:TAhod+kjxTQHdqMNnDkg6IbLRrGovtI+3l1Isjd1wbI= -github.com/qdm12/golibs v0.0.0-20200208143139-ccf3a2b4be96/go.mod h1:YULaFjj6VGmhjak6f35sUWwEleHUmngN5IQ3kdvd6XE= +github.com/qdm12/golibs v0.0.0-20200208153322-66b2eb719e21 h1:Nza/Ar6tPYhDzkiNzbaJZHl4+GUXTqbtjGXuWenkqpQ= +github.com/qdm12/golibs v0.0.0-20200208153322-66b2eb719e21/go.mod h1:YULaFjj6VGmhjak6f35sUWwEleHUmngN5IQ3kdvd6XE= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 6b637f5f..f2d52159 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -2,6 +2,7 @@ package dns import ( "io" + "net" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/files" @@ -16,8 +17,10 @@ type Configurator interface { DownloadRootHints(uid, gid int) error DownloadRootKey(uid, gid int) error MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) - SetLocalNameserver() error + UseDNSInternally(IP net.IP) + UseDNSSystemWide(IP net.IP) error Start(logLevel uint8) (stdout io.ReadCloser, err error) + WaitForUnbound() (err error) Version() (version string, err error) } @@ -26,6 +29,7 @@ type configurator struct { client network.Client fileManager files.FileManager commander command.Commander + lookupIP func(host string) ([]net.IP, error) } func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator { @@ -34,5 +38,6 @@ func NewConfigurator(logger logging.Logger, client network.Client, fileManager f client: client, fileManager: fileManager, commander: command.NewCommander(), + lookupIP: net.LookupIP, } } diff --git a/internal/dns/nameserver.go b/internal/dns/nameserver.go new file mode 100644 index 00000000..bddeec13 --- /dev/null +++ b/internal/dns/nameserver.go @@ -0,0 +1,47 @@ +package dns + +import ( + "context" + "net" + "strings" + + "github.com/qdm12/private-internet-access-docker/internal/constants" +) + +// UseDNSInternally is to change the Go program DNS only +func (c *configurator) UseDNSInternally(IP net.IP) { + c.logger.Info("%s: using DNS address %s internally", logPrefix, IP.String()) + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "udp", net.JoinHostPort(IP.String(), "53")) + }, + } +} + +// UseDNSSystemWide changes the nameserver to use for DNS system wide +func (c *configurator) UseDNSSystemWide(IP net.IP) error { + c.logger.Info("%s: using DNS address %s system wide", logPrefix, IP.String()) + data, err := c.fileManager.ReadFile(string(constants.ResolvConf)) + if err != nil { + return err + } + s := strings.TrimSuffix(string(data), "\n") + lines := strings.Split(s, "\n") + if len(lines) == 1 && lines[0] == "" { + lines = nil + } + found := false + for i := range lines { + if strings.HasPrefix(lines[i], "nameserver ") { + lines[i] = "nameserver " + IP.String() + found = true + } + } + if !found { + lines = append(lines, "nameserver "+IP.String()) + } + data = []byte(strings.Join(lines, "\n")) + return c.fileManager.WriteToFile(string(constants.ResolvConf), data) +} diff --git a/internal/dns/os_test.go b/internal/dns/nameserver_test.go similarity index 90% rename from internal/dns/os_test.go rename to internal/dns/nameserver_test.go index f6201d1f..0dcd9768 100644 --- a/internal/dns/os_test.go +++ b/internal/dns/nameserver_test.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "net" "testing" filesmocks "github.com/qdm12/golibs/files/mocks" @@ -11,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func Test_SetLocalNameserver(t *testing.T) { +func Test_UseDNSSystemWide(t *testing.T) { t.Parallel() tests := map[string]struct { data []byte @@ -53,12 +54,12 @@ func Test_SetLocalNameserver(t *testing.T) { Return(tc.writeErr).Once() } logger := &loggingmocks.Logger{} - logger.On("Info", "%s: setting local nameserver to 127.0.0.1", logPrefix).Once() + logger.On("Info", "%s: using DNS address %s system wide", logPrefix, "127.0.0.1").Once() c := &configurator{ fileManager: fileManager, logger: logger, } - err := c.SetLocalNameserver() + err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}) if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) diff --git a/internal/dns/os.go b/internal/dns/os.go index 6a0afda2..7ab6b4ed 100644 --- a/internal/dns/os.go +++ b/internal/dns/os.go @@ -1,13 +1,14 @@ package dns import ( + "net" "strings" "github.com/qdm12/private-internet-access-docker/internal/constants" ) -func (c *configurator) SetLocalNameserver() error { - c.logger.Info("%s: setting local nameserver to 127.0.0.1", logPrefix) +func (c *configurator) SetNameserver(IP net.IP) error { + c.logger.Info("%s: setting local nameserver to %s", logPrefix, IP.String()) data, err := c.fileManager.ReadFile(string(constants.ResolvConf)) if err != nil { return err @@ -20,12 +21,12 @@ func (c *configurator) SetLocalNameserver() error { found := false for i := range lines { if strings.HasPrefix(lines[i], "nameserver ") { - lines[i] = "nameserver 127.0.0.1" + lines[i] = "nameserver " + IP.String() found = true } } if !found { - lines = append(lines, "nameserver 127.0.0.1") + lines = append(lines, "nameserver "+IP.String()) } data = []byte(strings.Join(lines, "\n")) return c.fileManager.WriteToFile(string(constants.ResolvConf), data) diff --git a/internal/dns/wait.go b/internal/dns/wait.go new file mode 100644 index 00000000..62e6da1e --- /dev/null +++ b/internal/dns/wait.go @@ -0,0 +1,20 @@ +package dns + +import ( + "fmt" + "time" +) + +func (c *configurator) WaitForUnbound() (err error) { + const maxTries = 10 + const hostToResolve = "github.com" + for try := 1; try <= maxTries; try++ { + _, err := c.lookupIP(hostToResolve) + if err == nil { + return nil + } + c.logger.Warn("could not resolve %s (try %d of %d)", hostToResolve, try, maxTries) + time.Sleep(time.Duration(maxTries * 50 * time.Millisecond)) + } + return fmt.Errorf("Unbound does not seem to be working after %d tries", maxTries) +}