feat(healthcheck): add HEALTH_SMALL_CHECK_TYPE option which can be dns or icmp (default)

Note: if ICMP is not permitted, it falls back to DNS anyway.
This commit is contained in:
Quentin McGaw
2025-11-20 15:05:38 +00:00
parent 815fcdb711
commit 9e5624d32b
6 changed files with 76 additions and 28 deletions

View File

@@ -165,6 +165,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \ HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \
HEALTH_TARGET_ADDRESSES=cloudflare.com:443,github.com:443 \ HEALTH_TARGET_ADDRESSES=cloudflare.com:443,github.com:443 \
HEALTH_ICMP_TARGET_IPS=1.1.1.1,8.8.8.8 \ HEALTH_ICMP_TARGET_IPS=1.1.1.1,8.8.8.8 \
HEALTH_SMALL_CHECK_TYPE=icmp \
HEALTH_RESTART_VPN=on \ HEALTH_RESTART_VPN=on \
# DNS # DNS
DNS_SERVER=on \ DNS_SERVER=on \

View File

@@ -29,6 +29,10 @@ type Health struct {
// although this can be less reliable. It defaults to [1.1.1.1,8.8.8.8], // although this can be less reliable. It defaults to [1.1.1.1,8.8.8.8],
// and cannot be left empty in the internal state. // and cannot be left empty in the internal state.
ICMPTargetIPs []netip.Addr ICMPTargetIPs []netip.Addr
// SmallCheckType is the type of small health check to perform.
// It can be "icmp" or "dns", and defaults to "icmp".
// Note it changes automatically to dns if icmp is not supported.
SmallCheckType string
// RestartVPN indicates whether to restart the VPN connection // RestartVPN indicates whether to restart the VPN connection
// when the healthcheck fails. // when the healthcheck fails.
RestartVPN *bool RestartVPN *bool
@@ -37,6 +41,7 @@ type Health struct {
var ( var (
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid") ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible") ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
) )
func (h Health) Validate() (err error) { func (h Health) Validate() (err error) {
@@ -55,6 +60,11 @@ func (h Health) Validate() (err error) {
} }
} }
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
if err != nil {
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
}
return nil return nil
} }
@@ -63,6 +73,7 @@ func (h *Health) copy() (copied Health) {
ServerAddress: h.ServerAddress, ServerAddress: h.ServerAddress,
TargetAddresses: h.TargetAddresses, TargetAddresses: h.TargetAddresses,
ICMPTargetIPs: gosettings.CopySlice(h.ICMPTargetIPs), ICMPTargetIPs: gosettings.CopySlice(h.ICMPTargetIPs),
SmallCheckType: h.SmallCheckType,
RestartVPN: gosettings.CopyPointer(h.RestartVPN), RestartVPN: gosettings.CopyPointer(h.RestartVPN),
} }
} }
@@ -74,6 +85,7 @@ func (h *Health) OverrideWith(other Health) {
h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress) h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress)
h.TargetAddresses = gosettings.OverrideWithSlice(h.TargetAddresses, other.TargetAddresses) h.TargetAddresses = gosettings.OverrideWithSlice(h.TargetAddresses, other.TargetAddresses)
h.ICMPTargetIPs = gosettings.OverrideWithSlice(h.ICMPTargetIPs, other.ICMPTargetIPs) h.ICMPTargetIPs = gosettings.OverrideWithSlice(h.ICMPTargetIPs, other.ICMPTargetIPs)
h.SmallCheckType = gosettings.OverrideWithComparable(h.SmallCheckType, other.SmallCheckType)
h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN) h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN)
} }
@@ -84,6 +96,7 @@ func (h *Health) SetDefaults() {
netip.AddrFrom4([4]byte{1, 1, 1, 1}), netip.AddrFrom4([4]byte{1, 1, 1, 1}),
netip.AddrFrom4([4]byte{8, 8, 8, 8}), netip.AddrFrom4([4]byte{8, 8, 8, 8}),
}) })
h.SmallCheckType = gosettings.DefaultComparable(h.SmallCheckType, "icmp")
h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true) h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true)
} }
@@ -98,13 +111,19 @@ func (h Health) toLinesNode() (node *gotree.Node) {
for _, targetAddr := range h.TargetAddresses { for _, targetAddr := range h.TargetAddresses {
targetAddrs.Append(targetAddr) targetAddrs.Append(targetAddr)
} }
if len(h.ICMPTargetIPs) == 1 && h.ICMPTargetIPs[0].IsUnspecified() { switch h.SmallCheckType {
node.Appendf("ICMP target IP: VPN server IP address") case "icmp":
} else { icmpNode := node.Appendf("Small health check type: ICMP echo request")
icmpIPs := node.Appendf("ICMP target IPs:") if len(h.ICMPTargetIPs) == 1 && h.ICMPTargetIPs[0].IsUnspecified() {
for _, ip := range h.ICMPTargetIPs { icmpNode.Appendf("ICMP target IP: VPN server IP address")
icmpIPs.Append(ip.String()) } else {
icmpIPs := icmpNode.Appendf("ICMP target IPs:")
for _, ip := range h.ICMPTargetIPs {
icmpIPs.Append(ip.String())
}
} }
case "dns":
node.Appendf("Small health check type: Plain DNS lookup over UDP")
} }
node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN)) node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN))
return node return node
@@ -118,6 +137,7 @@ func (h *Health) Read(r *reader.Reader) (err error) {
if err != nil { if err != nil {
return err return err
} }
h.SmallCheckType = r.String("HEALTH_SMALL_CHECK_TYPE")
h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN") h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN")
if err != nil { if err != nil {
return err return err

View File

@@ -60,9 +60,10 @@ func Test_Settings_String(t *testing.T) {
| ├── Target addresses: | ├── Target addresses:
| | ├── cloudflare.com:443 | | ├── cloudflare.com:443
| | └── github.com:443 | | └── github.com:443
| ├── ICMP target IPs: | ├── Small health check type: ICMP echo request
| | ├── 1.1.1.1 | | └── ICMP target IPs:
| | └── 8.8.8.8 | | ├── 1.1.1.1
| | └── 8.8.8.8
| └── Restart VPN on healthcheck failure: yes | └── Restart VPN on healthcheck failure: yes
├── Shadowsocks server settings: ├── Shadowsocks server settings:
| └── Enabled: no | └── Enabled: no

View File

@@ -16,16 +16,16 @@ import (
) )
type Checker struct { type Checker struct {
tlsDialAddrs []string tlsDialAddrs []string
dialer *net.Dialer dialer *net.Dialer
echoer *icmp.Echoer echoer *icmp.Echoer
dnsClient *dns.Client dnsClient *dns.Client
logger Logger logger Logger
icmpTargetIPs []netip.Addr icmpTargetIPs []netip.Addr
configMutex sync.Mutex smallCheckType string
configMutex sync.Mutex
icmpNotPermitted bool icmpNotPermitted bool
smallCheckName string
// Internal periodic service signals // Internal periodic service signals
stop context.CancelFunc stop context.CancelFunc
@@ -45,14 +45,17 @@ func NewChecker(logger Logger) *Checker {
} }
} }
// SetConfig sets the TCP+TLS dial addresses and the ICMP echo IP address // SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
// to target by the [Checker]. // to target and the desired small check type (dns or icmp).
// This function MUST be called before calling [Checker.Start]. // This function MUST be called before calling [Checker.Start].
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr) { func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
smallCheckType string,
) {
c.configMutex.Lock() c.configMutex.Lock()
defer c.configMutex.Unlock() defer c.configMutex.Unlock()
c.tlsDialAddrs = tlsDialAddrs c.tlsDialAddrs = tlsDialAddrs
c.icmpTargetIPs = icmpTargets c.icmpTargetIPs = icmpTargets
c.smallCheckType = smallCheckType
} }
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check, // Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
@@ -63,10 +66,15 @@ func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr) {
// It returns an error if the initial TCP+TLS check fails. // It returns an error if the initial TCP+TLS check fails.
// The Checker has to be ultimately stopped by calling [Checker.Stop]. // The Checker has to be ultimately stopped by calling [Checker.Stop].
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) { func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 { if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
panic("call Checker.SetConfig with non empty values before Checker.Start") panic("call Checker.SetConfig with non empty values before Checker.Start")
} }
if c.icmpNotPermitted {
// restore forced check type to dns if icmp was found to be not permitted
c.smallCheckType = smallCheckDNS
}
err = c.startupCheck(ctx) err = c.startupCheck(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("startup check: %w", err) return nil, fmt.Errorf("startup check: %w", err)
@@ -77,7 +85,6 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
c.stop = cancel c.stop = cancel
done := make(chan struct{}) done := make(chan struct{})
c.done = done c.done = done
c.smallCheckName = "ICMP echo"
const smallCheckPeriod = time.Minute const smallCheckPeriod = time.Minute
smallCheckTimer := time.NewTimer(smallCheckPeriod) smallCheckTimer := time.NewTimer(smallCheckPeriod)
const fullCheckPeriod = 5 * time.Minute const fullCheckPeriod = 5 * time.Minute
@@ -119,6 +126,7 @@ func (c *Checker) Stop() error {
<-c.done <-c.done
c.tlsDialAddrs = nil c.tlsDialAddrs = nil
c.icmpTargetIPs = nil c.icmpTargetIPs = nil
c.smallCheckType = ""
return nil return nil
} }
@@ -140,20 +148,21 @@ func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
30 * time.Second, 30 * time.Second,
} }
check := func(ctx context.Context, try int) error { check := func(ctx context.Context, try int) error {
if c.icmpNotPermitted { if c.smallCheckType == smallCheckDNS {
return c.dnsClient.Check(ctx) return c.dnsClient.Check(ctx)
} }
ip := icmpTargetIPs[try%len(icmpTargetIPs)] ip := icmpTargetIPs[try%len(icmpTargetIPs)]
err := c.echoer.Echo(ctx, ip) err := c.echoer.Echo(ctx, ip)
if errors.Is(err, icmp.ErrNotPermitted) { if errors.Is(err, icmp.ErrNotPermitted) {
c.icmpNotPermitted = true c.icmpNotPermitted = true
c.smallCheckName = "plain DNS over UDP" c.smallCheckType = smallCheckDNS
c.logger.Infof("%s; permanently falling back to %s checks.", c.smallCheckName, err) c.logger.Infof("%s; permanently falling back to %s checks",
smallCheckTypeToString(c.smallCheckType), err)
return c.dnsClient.Check(ctx) return c.dnsClient.Check(ctx)
} }
return err return err
} }
return withRetries(ctx, tryTimeouts, c.logger, c.smallCheckName, check) return withRetries(ctx, tryTimeouts, c.logger, smallCheckTypeToString(c.smallCheckType), check)
} }
func (c *Checker) fullPeriodicCheck(ctx context.Context) error { func (c *Checker) fullPeriodicCheck(ctx context.Context) error {
@@ -299,3 +308,19 @@ func (c *Checker) startupCheck(ctx context.Context) error {
} }
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", ")) return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
} }
// Identifiers of the small periodic health check types known to
// smallCheckTypeToString.
const (
	smallCheckDNS  = "dns"
	smallCheckICMP = "icmp"
)

// smallCheckTypeToString returns the human readable label matching the
// small check type identifier given: "ICMP echo" for smallCheckICMP and
// "plain DNS over UDP" for smallCheckDNS.
// It panics for any other value, since an unknown identifier reaching
// this function is treated as a programming error.
func smallCheckTypeToString(smallCheckType string) string {
	if smallCheckType == smallCheckICMP {
		return "ICMP echo"
	}
	if smallCheckType == smallCheckDNS {
		return "plain DNS over UDP"
	}
	panic("unknown small check type: " + smallCheckType)
}

View File

@@ -101,7 +101,7 @@ type CmdStarter interface {
} }
type HealthChecker interface { type HealthChecker interface {
SetConfig(tlsDialAddrs []string, icmpTargetIPs []netip.Addr) SetConfig(tlsDialAddrs []string, icmpTargetIPs []netip.Addr, smallCheckType string)
Start(ctx context.Context) (runError <-chan error, err error) Start(ctx context.Context) (runError <-chan error, err error)
Stop() error Stop() error
} }

View File

@@ -35,7 +35,8 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() { if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
icmpTargetIPs = []netip.Addr{data.serverIP} icmpTargetIPs = []netip.Addr{data.serverIP}
} }
l.healthChecker.SetConfig(l.healthSettings.TargetAddresses, icmpTargetIPs) l.healthChecker.SetConfig(l.healthSettings.TargetAddresses, icmpTargetIPs,
l.healthSettings.SmallCheckType)
healthErrCh, err := l.healthChecker.Start(ctx) healthErrCh, err := l.healthChecker.Start(ctx)
l.healthServer.SetError(err) l.healthServer.SetError(err)