diff --git a/README.md b/README.md index a2bb3157..4cc2818b 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ - Based on Alpine 3.11 for a small Docker image below 50MB - Supports **Private Internet Access**, **Mullvad** and **Windscribe** servers - DNS over TLS baked in with service provider(s) of your choice -- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses +- DNS fine blocking of malicious/ads/surveillance hostnames and IP addresses, with live update every 24 hours - Choose the vpn network protocol, `udp` or `tcp` - Built in firewall kill switch to allow traffic only with needed PIA servers and LAN devices - Built in SOCKS5 proxy (Shadowsocks, tunnels TCP+UDP) @@ -268,6 +268,7 @@ It can be useful to mount this file as a volume to read it from other containers A built-in HTTP server listens on port `8000` to modify the state of the container. You have the following routes available: - `http://:8000/openvpn/actions/restart` restarts the openvpn process +- `http://:8000/unbound/actions/restart` re-downloads the DNS files (crypto and block lists) and restarts the unbound process ## FAQ diff --git a/cmd/main.go b/cmd/main.go index 37cf11a1..1af0b8a1 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -249,7 +249,7 @@ func main() { go func() { <-connected.Done() // blocks until openvpn is connected onConnected(ctx, allSettings, logger, dnsConf, fileManager, waiter, - streamMerger, routingConf, defaultInterface, piaConf) + streamMerger, httpServer, routingConf, defaultInterface, piaConf) }() signalsCh := make(chan os.Signal, 1) @@ -354,7 +354,7 @@ func openvpnRunLoop(ctx context.Context, ovpnConf openvpn.Configurator, streamMe func onConnected(ctx context.Context, allSettings settings.Settings, logger logging.Logger, dnsConf dns.Configurator, fileManager files.FileManager, - waiter command.Waiter, streamMerger command.StreamMerger, + waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server, routingConf routing.Routing, defaultInterface string, piaConf pia.Configurator, ) { @@ -365,12 +365,7 @@ func onConnected(ctx context.Context, allSettings settings.Settings, } if allSettings.DNS.Enabled { - err := setupUnbound(ctx, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger) - if err != nil { - logger.Error("unbound dns over tls setup: %s", err) - } else { - logger.Info("unbound dns over tls setup: completed") - } + go unboundRunLoop(ctx, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger, httpServer) } ip, err := routingConf.CurrentPublicIP(defaultInterface) @@ -389,53 +384,95 @@ func onConnected(ctx context.Context, allSettings settings.Settings, } } -func setupUnbound(ctx context.Context, logger logging.Logger, dnsConf dns.Configurator, - settings settings.DNS, uid, gid int, - waiter command.Waiter, streamMerger command.StreamMerger, -) (err error) { - ctx, cancel := context.WithCancel(ctx) - defer func() { - if err != nil { - cancel() - } - }() - initialDNSToUse := constants.DNSProviderMapping()[settings.Providers[0]] - var ipToUse net.IP - for _, ipToUse = range initialDNSToUse.IPs { - if settings.IPv6 && ipToUse.To4() == nil { +func fallbackToUnencryptedDNS(dnsConf dns.Configurator, provider models.DNSProvider, ipv6 bool) error { + targetDNS := constants.DNSProviderMapping()[provider] + var targetIP net.IP + for _, targetIP = range targetDNS.IPs { + if ipv6 && targetIP.To4() == nil { break - } else if !settings.IPv6 && ipToUse.To4() != nil { + } else if !ipv6 && targetIP.To4() != nil { break } } - dnsConf.UseDNSInternally(ipToUse) + dnsConf.UseDNSInternally(targetIP) + return dnsConf.UseDNSSystemWide(targetIP) +} + +func unboundRun(ctx, unboundCtx context.Context, unboundCancel context.CancelFunc, dnsConf dns.Configurator, settings settings.DNS, uid, gid int, + streamMerger command.StreamMerger, waiter command.Waiter, httpServer server.Server) (newCtx context.Context, newCancel context.CancelFunc, err error) { if err := dnsConf.DownloadRootHints(uid, gid); err != nil { - return err + return unboundCtx, unboundCancel, err } if err := dnsConf.DownloadRootKey(uid, gid); err != nil { - return err + return unboundCtx, unboundCancel, err } if err := dnsConf.MakeUnboundConf(settings, uid, gid); err != nil { - return err + return unboundCtx, unboundCancel, err } - stream, waitFn, err := dnsConf.Start(ctx, settings.VerbosityDetailsLevel) + unboundCancel() + newCtx, newCancel = context.WithTimeout(ctx, 24*time.Hour) + stream, waitFn, err := dnsConf.Start(newCtx, settings.VerbosityDetailsLevel) if err != nil { - return err + newCancel() + if fallbackErr := fallbackToUnencryptedDNS(dnsConf, settings.Providers[0], settings.IPv6); err != nil { + return newCtx, newCancel, fmt.Errorf("%s: %w", err, fallbackErr) + } + return newCtx, newCancel, err } - waiter.Add(func() error { //nolint:scopelint - err := waitFn() - logger.Error("unbound: %s", err) //nolint:scopelint - return err - }) - go streamMerger.Merge(ctx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) + go streamMerger.Merge(newCtx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound if err := dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound - return err + newCancel() + if fallbackErr := fallbackToUnencryptedDNS(dnsConf, settings.Providers[0], settings.IPv6); err != nil { + return newCtx, newCancel, fmt.Errorf("%s: %w", err, fallbackErr) + } + return newCtx, newCancel, err } if err := dnsConf.WaitForUnbound(); err != nil { - return err + newCancel() + if fallbackErr := fallbackToUnencryptedDNS(dnsConf, settings.Providers[0], settings.IPv6); err != nil { + return newCtx, newCancel, fmt.Errorf("%s: %w", err, fallbackErr) + } + return newCtx, newCancel, err + } + // Unbound is up and running at this point + httpServer.SetUnboundRestart(newCancel) + waitErrors := make(chan error) + waiter.Add(func() error { //nolint:scopelint + return <-waitErrors + }) + err = waitFn() + waitErrors <- err + if newCtx.Err() == context.Canceled || newCtx.Err() == context.DeadlineExceeded { + return newCtx, newCancel, nil + } + return newCtx, newCancel, err +} + +func unboundRunLoop(ctx context.Context, logger logging.Logger, dnsConf dns.Configurator, + settings settings.DNS, uid, gid int, + waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server, +) { + logger = logger.WithPrefix("unbound dns over tls setup: ") + if err := fallbackToUnencryptedDNS(dnsConf, settings.Providers[0], settings.IPv6); err != nil { + logger.Error(err) + } + unboundCtx, unboundCancel := context.WithCancel(ctx) + defer unboundCancel() + for { + if ctx.Err() == context.Canceled { + logger.Info("shutting down") + break + } + var err error + unboundCtx, unboundCancel, err = unboundRun(ctx, unboundCtx, unboundCancel, dnsConf, settings, uid, gid, streamMerger, waiter, httpServer) + if err != nil { + logger.Error(err) + time.Sleep(10 * time.Second) + continue + } + logger.Info("attempting restart") } - return nil } func setupPortForwarding(logger logging.Logger, piaConf pia.Configurator, settings settings.PIA, uid, gid int) { diff --git a/internal/server/server.go b/internal/server/server.go index 31776bb9..324c2123 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -12,6 +12,7 @@ import ( type Server interface { SetOpenVPNRestart(f func()) + SetUnboundRestart(f func()) Run(ctx context.Context) error } @@ -21,16 +22,22 @@ type server struct { restartOpenvpn func() restartOpenvpnSet context.Context restartOpenvpnSetSignal func() + restartUnbound func() + restartUnboundSet context.Context + restartUnboundSetSignal func() sync.RWMutex } func New(address string, logger logging.Logger) Server { restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background()) + restartUnboundSet, restartUnboundSetSignal := context.WithCancel(context.Background()) return &server{ address: address, logger: logger.WithPrefix("http server: "), restartOpenvpnSet: restartOpenvpnSet, restartOpenvpnSetSignal: restartOpenvpnSetSignal, + restartUnboundSet: restartUnboundSet, + restartUnboundSetSignal: restartUnboundSetSignal, } } @@ -39,6 +46,10 @@ func (s *server) Run(ctx context.Context) error { s.logger.Warn("restartOpenvpn function is not set, waiting...") <-s.restartOpenvpnSet.Done() } + if s.restartUnboundSet.Err() == nil { + s.logger.Warn("restartUnbound function is not set, waiting...") + <-s.restartUnboundSet.Done() + } server := http.Server{Addr: s.address, Handler: s.makeHandler()} go func() { <-ctx.Done() @@ -61,6 +72,15 @@ func (s *server) SetOpenVPNRestart(f func()) { } } +func (s *server) SetUnboundRestart(f func()) { + s.Lock() + defer s.Unlock() + s.restartUnbound = f + if s.restartUnboundSet.Err() == nil { + s.restartUnboundSetSignal() + } +} + func (s *server) makeHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { s.logger.Info("HTTP %s %s", r.Method, r.RequestURI) @@ -71,6 +91,10 @@ func (s *server) makeHandler() http.HandlerFunc { s.RLock() defer s.RUnlock() s.restartOpenvpn() + case "/unbound/actions/restart": + s.RLock() + defer s.RUnlock() + s.restartUnbound() default: routeDoesNotExist(s.logger, w, r) } @@ -87,4 +111,3 @@ func routeDoesNotExist(logger logging.Logger, w http.ResponseWriter, r *http.Req logger.Error(err) } } -