DNS_KEEP_NAMESERVER variable, refers to #188

This commit is contained in:
Quentin McGaw
2020-07-11 23:51:53 +00:00
parent 78b63174ce
commit 8b096af04e
9 changed files with 32 additions and 11 deletions

View File

@@ -76,6 +76,7 @@ ENV VPNSP=pia \
UNBLOCK= \ UNBLOCK= \
DNS_UPDATE_PERIOD=24h \ DNS_UPDATE_PERIOD=24h \
DNS_PLAINTEXT_ADDRESS=1.1.1.1 \ DNS_PLAINTEXT_ADDRESS=1.1.1.1 \
DNS_KEEP_NAMESERVER=off \
# Firewall # Firewall
FIREWALL=on \ FIREWALL=on \
EXTRA_SUBNETS= \ EXTRA_SUBNETS= \

View File

@@ -221,6 +221,7 @@ None of the following values are required.
| `BLOCK_ADS` | `off` | `on`, `off` | Block ads hostnames and IPs with Unbound | | `BLOCK_ADS` | `off` | `on`, `off` | Block ads hostnames and IPs with Unbound |
| `UNBLOCK` | |i.e. `domain1.com,x.domain2.co.uk` | Comma separated list of domain names to leave unblocked with Unbound | | `UNBLOCK` | |i.e. `domain1.com,x.domain2.co.uk` | Comma separated list of domain names to leave unblocked with Unbound |
| `DNS_PLAINTEXT_ADDRESS` | `1.1.1.1` | Any IP address | IP address to use as DNS resolver if `DOT` is `off` | | `DNS_PLAINTEXT_ADDRESS` | `1.1.1.1` | Any IP address | IP address to use as DNS resolver if `DOT` is `off` |
| `DNS_KEEP_NAMESERVER` | `off` | `on` or `off` | Keep the nameservers in /etc/resolv.conf untouched, but disabled DNS blocking features |
### Firewall ### Firewall

View File

@@ -17,7 +17,7 @@ type Configurator interface {
DownloadRootKey(uid, gid int) error DownloadRootKey(uid, gid int) error
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
UseDNSInternally(IP net.IP) UseDNSInternally(IP net.IP)
UseDNSSystemWide(IP net.IP) error UseDNSSystemWide(ip net.IP, keepNameserver bool) error
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
WaitForUnbound() (err error) WaitForUnbound() (err error)
Version(ctx context.Context) (version string, err error) Version(ctx context.Context) (version string, err error)

View File

@@ -104,8 +104,8 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
// Started successfully // Started successfully
go l.streamMerger.Merge(unboundCtx, stream, go l.streamMerger.Merge(unboundCtx, stream,
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, l.settings.KeepNameserver); err != nil { // use Unbound
l.logger.Error(err) l.logger.Error(err)
} }
if err := l.conf.WaitForUnbound(); err != nil { if err := l.conf.WaitForUnbound(); err != nil {
@@ -148,7 +148,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
if targetIP != nil { if targetIP != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP) l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
l.conf.UseDNSInternally(targetIP) l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil { if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
return return
@@ -161,7 +161,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
if targetIP.To4() != nil { if targetIP.To4() != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP) l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
l.conf.UseDNSInternally(targetIP) l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil { if err := l.conf.UseDNSSystemWide(targetIP, l.settings.KeepNameserver); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
return return

View File

@@ -21,7 +21,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
} }
// UseDNSSystemWide changes the nameserver to use for DNS system wide // UseDNSSystemWide changes the nameserver to use for DNS system wide
func (c *configurator) UseDNSSystemWide(ip net.IP) error { func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String()) c.logger.Info("using DNS address %s system wide", ip.String())
data, err := c.fileManager.ReadFile(string(constants.ResolvConf)) data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
if err != nil { if err != nil {
@@ -33,10 +33,12 @@ func (c *configurator) UseDNSSystemWide(ip net.IP) error {
lines = nil lines = nil
} }
found := false found := false
for i := range lines { if !keepNameserver { // default
if strings.HasPrefix(lines[i], "nameserver ") { for i := range lines {
lines[i] = "nameserver " + ip.String() if strings.HasPrefix(lines[i], "nameserver ") {
found = true lines[i] = "nameserver " + ip.String()
found = true
}
} }
} }
if !found { if !found {

View File

@@ -62,7 +62,7 @@ func Test_UseDNSSystemWide(t *testing.T) {
fileManager: fileManager, fileManager: fileManager,
logger: logger, logger: logger,
} }
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}) err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil { if tc.err != nil {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error()) assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -157,3 +157,9 @@ func (r *reader) GetDNSPlaintext() (ip net.IP, err error) {
} }
return ip, nil return ip, nil
} }
// GetDNSKeepNameserver obtains if the nameserver present in /etc/resolv.conf
// should be kept instead of overridden, from the environment variable DNS_KEEP_NAMESERVER
func (r *reader) GetDNSKeepNameserver() (on bool, err error) {
return r.envParams.GetOnOff("DNS_KEEP_NAMESERVER", libparams.Default("off"))
}

View File

@@ -30,6 +30,7 @@ type Reader interface {
GetDNSOverTLSIPv6() (ipv6 bool, err error) GetDNSOverTLSIPv6() (ipv6 bool, err error)
GetDNSUpdatePeriod() (period time.Duration, err error) GetDNSUpdatePeriod() (period time.Duration, err error)
GetDNSPlaintext() (ip net.IP, err error) GetDNSPlaintext() (ip net.IP, err error)
GetDNSKeepNameserver() (on bool, err error)
// System // System
GetUID() (uid int, err error) GetUID() (uid int, err error)

View File

@@ -14,6 +14,7 @@ import (
// DNS contains settings to configure Unbound for DNS over TLS operation // DNS contains settings to configure Unbound for DNS over TLS operation
type DNS struct { type DNS struct {
Enabled bool Enabled bool
KeepNameserver bool
Providers []models.DNSProvider Providers []models.DNSProvider
PlaintextAddress net.IP PlaintextAddress net.IP
AllowedHostnames []string AllowedHostnames []string
@@ -61,6 +62,10 @@ func (d *DNS) String() string {
if d.UpdatePeriod > 0 { if d.UpdatePeriod > 0 {
update = fmt.Sprintf("every %s", d.UpdatePeriod) update = fmt.Sprintf("every %s", d.UpdatePeriod)
} }
keepNameserver := "no"
if d.KeepNameserver {
keepNameserver = "yes"
}
settingsList := []string{ settingsList := []string{
"DNS over TLS settings:", "DNS over TLS settings:",
"DNS over TLS provider:\n |--" + strings.Join(providersStr, "\n |--"), "DNS over TLS provider:\n |--" + strings.Join(providersStr, "\n |--"),
@@ -75,6 +80,7 @@ func (d *DNS) String() string {
"Validation log level: " + fmt.Sprintf("%d/2", d.ValidationLogLevel), "Validation log level: " + fmt.Sprintf("%d/2", d.ValidationLogLevel),
"IPv6 resolution: " + ipv6, "IPv6 resolution: " + ipv6,
"Update: " + update, "Update: " + update,
"Keep nameserver (disabled blocking): " + keepNameserver,
} }
return strings.Join(settingsList, "\n |--") return strings.Join(settingsList, "\n |--")
} }
@@ -137,6 +143,10 @@ func GetDNSSettings(paramsReader params.Reader) (settings DNS, err error) {
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.KeepNameserver, err = paramsReader.GetDNSKeepNameserver()
if err != nil {
return settings, err
}
// Consistency check // Consistency check
IPv6Support := false IPv6Support := false