package healthcheck import ( "context" "crypto/tls" "errors" "fmt" "net" "net/netip" "strings" "sync" "time" "github.com/qdm12/gluetun/internal/healthcheck/dns" "github.com/qdm12/gluetun/internal/healthcheck/icmp" ) type Checker struct { tlsDialAddrs []string dialer *net.Dialer echoer *icmp.Echoer dnsClient *dns.Client logger Logger icmpTargetIPs []netip.Addr smallCheckType string configMutex sync.Mutex icmpNotPermitted bool // Internal periodic service signals stop context.CancelFunc done <-chan struct{} } func NewChecker(logger Logger) *Checker { return &Checker{ dialer: &net.Dialer{ Resolver: &net.Resolver{ PreferGo: true, }, }, echoer: icmp.NewEchoer(logger), dnsClient: dns.New(), logger: logger, } } // SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address // to target and the desired small check type (dns or icmp). // This function MUST be called before calling [Checker.Start]. func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr, smallCheckType string, ) { c.configMutex.Lock() defer c.configMutex.Unlock() c.tlsDialAddrs = tlsDialAddrs c.icmpTargetIPs = icmpTargets c.smallCheckType = smallCheckType } // Start starts the checker by first running a blocking 6s-timed TCP+TLS check, // and, on success, starts the periodic checks in a separate goroutine: // - a "small" ICMP echo check every minute // - a "full" TCP+TLS check every 5 minutes // It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed. // It returns an error if the initial TCP+TLS check fails. // The Checker has to be ultimately stopped by calling [Checker.Stop]. func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) { if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" { panic("call Checker.SetConfig with non empty values before Checker.Start") } if c.icmpNotPermitted { // restore forced check type to dns if icmp was found to be not permitted c.smallCheckType = smallCheckDNS } err = c.startupCheck(ctx) if err != nil { return nil, fmt.Errorf("startup check: %w", err) } ready := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) c.stop = cancel done := make(chan struct{}) c.done = done const smallCheckPeriod = time.Minute smallCheckTimer := time.NewTimer(smallCheckPeriod) const fullCheckPeriod = 5 * time.Minute fullCheckTimer := time.NewTimer(fullCheckPeriod) runErrorCh := make(chan error) runError = runErrorCh go func() { defer close(done) close(ready) for { select { case <-ctx.Done(): fullCheckTimer.Stop() smallCheckTimer.Stop() return case <-smallCheckTimer.C: err := c.smallPeriodicCheck(ctx) if err != nil { err = fmt.Errorf("small periodic check: %w", err) } runErrorCh <- err smallCheckTimer.Reset(smallCheckPeriod) case <-fullCheckTimer.C: err := c.fullPeriodicCheck(ctx) if err != nil { err = fmt.Errorf("full periodic check: %w", err) } runErrorCh <- err fullCheckTimer.Reset(fullCheckPeriod) } } }() <-ready return runError, nil } func (c *Checker) Stop() error { c.stop() <-c.done c.tlsDialAddrs = nil c.icmpTargetIPs = nil c.smallCheckType = "" return nil } func (c *Checker) smallPeriodicCheck(ctx context.Context) error { c.configMutex.Lock() icmpTargetIPs := make([]netip.Addr, len(c.icmpTargetIPs)) copy(icmpTargetIPs, c.icmpTargetIPs) c.configMutex.Unlock() tryTimeouts := []time.Duration{ 5 * time.Second, 5 * time.Second, 5 * time.Second, 10 * time.Second, 10 * time.Second, 10 * time.Second, 15 * time.Second, 15 * time.Second, 15 * time.Second, 30 * time.Second, } check := func(ctx context.Context, try int) error { if c.smallCheckType == smallCheckDNS { return c.dnsClient.Check(ctx) } ip := icmpTargetIPs[try%len(icmpTargetIPs)] err := c.echoer.Echo(ctx, ip) if errors.Is(err, icmp.ErrNotPermitted) { c.icmpNotPermitted = true c.smallCheckType = smallCheckDNS c.logger.Infof("%s; permanently falling back to %s checks", smallCheckTypeToString(c.smallCheckType), err) return c.dnsClient.Check(ctx) } return err } return withRetries(ctx, tryTimeouts, c.logger, smallCheckTypeToString(c.smallCheckType), check) } func (c *Checker) fullPeriodicCheck(ctx context.Context) error { // 20s timeout in case the connection is under stress // See https://github.com/qdm12/gluetun/issues/2270 tryTimeouts := []time.Duration{10 * time.Second, 15 * time.Second, 30 * time.Second} check := func(ctx context.Context, try int) error { tlsDialAddr := c.tlsDialAddrs[try%len(c.tlsDialAddrs)] return tcpTLSCheck(ctx, c.dialer, tlsDialAddr) } return withRetries(ctx, tryTimeouts, c.logger, "TCP+TLS dial", check) } func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error { // TODO use mullvad API if current provider is Mullvad address, err := makeAddressToDial(targetAddress) if err != nil { return err } const dialNetwork = "tcp4" connection, err := dialer.DialContext(ctx, dialNetwork, address) if err != nil { return fmt.Errorf("dialing: %w", err) } if strings.HasSuffix(address, ":443") { host, _, err := net.SplitHostPort(address) if err != nil { return fmt.Errorf("splitting host and port: %w", err) } tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: host, } tlsConnection := tls.Client(connection, tlsConfig) err = tlsConnection.HandshakeContext(ctx) if err != nil { return fmt.Errorf("running TLS handshake: %w", err) } } err = connection.Close() if err != nil { return fmt.Errorf("closing connection: %w", err) } return nil } func makeAddressToDial(address string) (addressToDial string, err error) { host, port, err := net.SplitHostPort(address) if err != nil { addrErr := new(net.AddrError) ok := errors.As(err, &addrErr) if !ok || addrErr.Err != "missing port in address" { return "", fmt.Errorf("splitting host and port from address: %w", err) } host = address const defaultPort = "443" port = defaultPort } address = net.JoinHostPort(host, port) return address, nil } var ErrAllCheckTriesFailed = errors.New("all check tries failed") func withRetries(ctx context.Context, tryTimeouts []time.Duration, logger Logger, checkName string, check func(ctx context.Context, try int) error, ) error { maxTries := len(tryTimeouts) type errData struct { err error duration time.Duration } errs := make([]errData, maxTries) for i, timeout := range tryTimeouts { start := time.Now() checkCtx, cancel := context.WithTimeout(ctx, timeout) err := check(checkCtx, i) cancel() switch { case err == nil: return nil case ctx.Err() != nil: return fmt.Errorf("%s: %w", checkName, ctx.Err()) } logger.Debugf("%s attempt %d/%d failed: %s", checkName, i+1, maxTries, err) errs[i].err = err errs[i].duration = time.Since(start) } errStrings := make([]string, len(errs)) for i, err := range errs { errStrings[i] = fmt.Sprintf("attempt %d (%s): %s", i+1, err.duration, err.err) } return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", ")) } func (c *Checker) startupCheck(ctx context.Context) error { // connection isn't under load yet when the checker starts, so a short // 6 seconds timeout suffices and provides quick enough feedback that // the new connection is not working. However, since the addresses to dial // may be multiple, we run the check in parallel. If any succeeds, the check passes. // This is to prevent false negatives at startup, if one of the addresses is down // for external reasons. const timeout = 6 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() errCh := make(chan error) for _, address := range c.tlsDialAddrs { go func(addr string) { err := tcpTLSCheck(ctx, c.dialer, addr) errCh <- err }(address) } errs := make([]error, 0, len(c.tlsDialAddrs)) success := false for range c.tlsDialAddrs { err := <-errCh if err == nil { success = true cancel() continue } else if success { continue // ignore canceled errors after success } c.logger.Debugf("startup check parallel attempt failed: %s", err) errs = append(errs, err) } if success { return nil } errStrings := make([]string, len(errs)) for i, err := range errs { errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err) } return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", ")) } const ( smallCheckDNS = "dns" smallCheckICMP = "icmp" ) func smallCheckTypeToString(smallCheckType string) string { switch smallCheckType { case smallCheckICMP: return "ICMP echo" case smallCheckDNS: return "plain DNS over UDP" default: panic("unknown small check type: " + smallCheckType) } }