Better checks for user provided private addresses

This commit is contained in:
Quentin McGaw
2020-04-26 13:28:14 +00:00
parent 97ea5f63b8
commit 36424c08ac
3 changed files with 23 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ package params
import (
"fmt"
"net"
"strings"
libparams "github.com/qdm12/golibs/params"
@@ -88,8 +89,7 @@ func (p *reader) GetDNSUnblockedHostnames() (hostnames []string, err error) {
s, err := p.envParams.GetEnv("UNBLOCK")
if err != nil {
return nil, err
}
if len(s) == 0 {
} else if len(s) == 0 {
return nil, nil
}
hostnames = strings.Split(s, ",")
@@ -109,10 +109,22 @@ func (p *reader) GetDNSOverTLSCaching() (caching bool, err error) {
// GetDNSOverTLSPrivateAddresses obtains if Unbound caching should be enable or not
// from the environment variable DOT_PRIVATE_ADDRESS
func (p *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string) {
s, _ := p.envParams.GetEnv("DOT_PRIVATE_ADDRESS")
privateAddresses = append(privateAddresses, strings.Split(s, ",")...)
return privateAddresses
func (p *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err error) {
s, err := p.envParams.GetEnv("DOT_PRIVATE_ADDRESS")
if err != nil {
return nil, err
} else if len(s) == 0 {
return nil, nil
}
privateAddresses = strings.Split(s, ",")
for _, address := range privateAddresses {
ip := net.ParseIP(address)
_, _, err := net.ParseCIDR(address)
if ip == nil && err != nil {
return nil, fmt.Errorf("private address %q is not a valid IP or CIDR range", address)
}
}
return privateAddresses, nil
}
// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using ipv6 DNS over TLS

View File

@@ -25,7 +25,7 @@ type Reader interface {
GetDNSSurveillanceBlocking() (blocking bool, err error)
GetDNSAdsBlocking() (blocking bool, err error)
GetDNSUnblockedHostnames() (hostnames []string, err error)
GetDNSOverTLSPrivateAddresses() (privateAddresses []string)
GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err error)
GetDNSOverTLSIPv6() (ipv6 bool, err error)
// System

View File

@@ -112,7 +112,10 @@ func GetDNSSettings(paramsReader params.Reader) (settings DNS, err error) {
if err != nil {
return settings, err
}
settings.PrivateAddresses = paramsReader.GetDNSOverTLSPrivateAddresses()
settings.PrivateAddresses, err = paramsReader.GetDNSOverTLSPrivateAddresses()
if err != nil {
return settings, err
}
settings.IPv6, err = paramsReader.GetDNSOverTLSIPv6()
if err != nil {
return settings, err