diff --git a/internal/params/dns.go b/internal/params/dns.go index dd48f0d0..57a30d6c 100644 --- a/internal/params/dns.go +++ b/internal/params/dns.go @@ -2,6 +2,7 @@ package params import ( "fmt" + "net" "strings" libparams "github.com/qdm12/golibs/params" @@ -88,8 +89,7 @@ func (p *reader) GetDNSUnblockedHostnames() (hostnames []string, err error) { s, err := p.envParams.GetEnv("UNBLOCK") if err != nil { return nil, err - } - if len(s) == 0 { + } else if len(s) == 0 { return nil, nil } hostnames = strings.Split(s, ",") @@ -109,10 +109,22 @@ func (p *reader) GetDNSOverTLSCaching() (caching bool, err error) { // GetDNSOverTLSPrivateAddresses obtains if Unbound caching should be enable or not // from the environment variable DOT_PRIVATE_ADDRESS -func (p *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string) { - s, _ := p.envParams.GetEnv("DOT_PRIVATE_ADDRESS") - privateAddresses = append(privateAddresses, strings.Split(s, ",")...) - return privateAddresses +func (p *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err error) { + s, err := p.envParams.GetEnv("DOT_PRIVATE_ADDRESS") + if err != nil { + return nil, err + } else if len(s) == 0 { + return nil, nil + } + privateAddresses = strings.Split(s, ",") + for _, address := range privateAddresses { + ip := net.ParseIP(address) + _, _, err := net.ParseCIDR(address) + if ip == nil && err != nil { + return nil, fmt.Errorf("private address %q is not a valid IP or CIDR range", address) + } + } + return privateAddresses, nil } // GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using ipv6 DNS over TLS diff --git a/internal/params/params.go b/internal/params/params.go index a20ac5b8..7fc0a603 100644 --- a/internal/params/params.go +++ b/internal/params/params.go @@ -25,7 +25,7 @@ type Reader interface { GetDNSSurveillanceBlocking() (blocking bool, err error) GetDNSAdsBlocking() (blocking bool, err error) GetDNSUnblockedHostnames() (hostnames []string, err error) - GetDNSOverTLSPrivateAddresses() (privateAddresses []string) + GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err error) GetDNSOverTLSIPv6() (ipv6 bool, err error) // System diff --git a/internal/settings/dns.go b/internal/settings/dns.go index de5d0762..d92b89db 100644 --- a/internal/settings/dns.go +++ b/internal/settings/dns.go @@ -112,7 +112,10 @@ func GetDNSSettings(paramsReader params.Reader) (settings DNS, err error) { if err != nil { return settings, err } - settings.PrivateAddresses = paramsReader.GetDNSOverTLSPrivateAddresses() + settings.PrivateAddresses, err = paramsReader.GetDNSOverTLSPrivateAddresses() + if err != nil { + return settings, err + } settings.IPv6, err = paramsReader.GetDNSOverTLSIPv6() if err != nil { return settings, err