Files
gluetun/internal/updater/resolver.go
2021-02-12 21:27:26 +00:00

131 lines
2.5 KiB
Go

package updater
import (
"bytes"
"context"
"net"
"sort"
"time"
)
// newResolver returns a resolver that uses Go's built-in DNS client
// (PreferGo) and sends every query over UDP to resolverAddress on port 53.
// resolverAddress is a host or IP without a port; net.JoinHostPort takes care
// of bracketing IPv6 literals.
//
// The network and address the resolver asks for are ignored: every dial goes
// to the same UDP endpoint.
// NOTE(review): this means any TCP retry requested by the resolver is also
// sent over UDP — confirm that is intended.
func newResolver(resolverAddress string) *net.Resolver {
	target := net.JoinHostPort(resolverAddress, "53")
	var dialer net.Dialer
	dial := func(ctx context.Context, _, _ string) (net.Conn, error) {
		return dialer.DialContext(ctx, "udp", target)
	}
	return &net.Resolver{PreferGo: true, Dial: dial}
}
// newLookupIP adapts r into a lookupIPFunc. The returned function calls
// r.LookupIPAddr for host and keeps only the IP of each resolved address,
// dropping the zone. On error it returns a nil slice and the lookup error
// unchanged.
func newLookupIP(r *net.Resolver) lookupIPFunc {
	return func(ctx context.Context, host string) (ips []net.IP, err error) {
		addrs, lookupErr := r.LookupIPAddr(ctx, host)
		if lookupErr != nil {
			return nil, lookupErr
		}
		ips = make([]net.IP, 0, len(addrs))
		for _, addr := range addrs {
			ips = append(ips, addr.IP)
		}
		return ips, nil
	}
}
// parallelResolve resolves every host concurrently, one goroutine per host.
// Each host is resolved with resolveRepeat using repetition and timeBetween.
//
// Hosts that resolve successfully are stored in hostToIPs, keyed by host.
// When a host fails:
//   - if failOnErr is false, its error message is appended to warnings;
//   - if failOnErr is true, the first error is returned as err and the shared
//     context is cancelled so the remaining lookups can stop early; later
//     errors are dropped.
//
// hostToIPs is returned even when err is non-nil, and may then be partial.
func parallelResolve(ctx context.Context, lookupIP lookupIPFunc, hosts []string,
	repetition int, timeBetween time.Duration, failOnErr bool) (
	hostToIPs map[string][]net.IP, warnings []string, err error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	type hostIPs struct {
		host string
		ips  []net.IP
	}

	// Every worker sends exactly one value on exactly one of these channels,
	// and the loop below receives exactly len(hosts) values. All senders are
	// therefore finished before the deferred closes run.
	successCh := make(chan hostIPs)
	defer close(successCh)
	failureCh := make(chan error)
	defer close(failureCh)

	// resolveOne resolves a single host and reports the outcome on one channel.
	resolveOne := func(host string) {
		ips, resolveErr := resolveRepeat(ctx, lookupIP, host, repetition, timeBetween)
		if resolveErr != nil {
			failureCh <- resolveErr
			return
		}
		successCh <- hostIPs{host: host, ips: ips}
	}
	for _, host := range hosts {
		go resolveOne(host)
	}

	hostToIPs = make(map[string][]net.IP, len(hosts))
	for pending := len(hosts); pending > 0; pending-- {
		select {
		case success := <-successCh:
			hostToIPs[success.host] = success.ips
		case resolveErr := <-failureCh:
			switch {
			case !failOnErr:
				warnings = append(warnings, resolveErr.Error())
			case err == nil:
				err = resolveErr
				cancel()
			}
		}
	}
	return hostToIPs, warnings, err
}
// resolveRepeat looks up host with lookupIP `repetition` times, waiting
// timeBetween between consecutive lookups. It returns the union of all IPs
// seen, deduplicated and sorted in ascending byte order. IPv4 addresses are
// returned in their 4-byte form.
//
// A repetition lower than 1 is treated as 1, so at least one lookup is made.
// The first lookup error aborts the whole operation: it is returned unchanged
// and any IPs gathered so far are discarded. If ctx is cancelled while waiting
// between lookups, ctx.Err() is returned.
func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string,
	repetition int, timeBetween time.Duration) (ips []net.IP, err error) {
	// Keyed by textual form. IPv4 addresses are normalised to 4 bytes first,
	// so the 4-byte and 16-byte encodings of one address share an entry.
	uniqueIPs := make(map[string]net.IP)
	for i := 1; ; i++ {
		newIPs, lookupErr := lookupIP(ctx, host)
		if lookupErr != nil {
			return nil, lookupErr
		}
		for _, ip := range newIPs {
			if ipv4 := ip.To4(); ipv4 != nil {
				ip = ipv4
			}
			uniqueIPs[ip.String()] = ip
		}

		// >= rather than == so a repetition below 1 cannot loop forever.
		if i >= repetition {
			break
		}

		timer := time.NewTimer(timeBetween)
		select {
		case <-timer.C:
		case <-ctx.Done():
			// Stop the timer and drain a value that may already have fired.
			if !timer.Stop() {
				<-timer.C
			}
			return nil, ctx.Err()
		}
	}

	ips = make([]net.IP, 0, len(uniqueIPs))
	for _, ip := range uniqueIPs {
		ips = append(ips, ip)
	}
	// sort.Slice needs a strict "less than". `< 1` would also report equal
	// elements as less, which violates that contract.
	sort.Slice(ips, func(i, j int) bool {
		return bytes.Compare(ips[i], ips[j]) < 0
	})
	return ips, nil
}