diff --git a/Dockerfile b/Dockerfile index d6e6c07c..d6e2a984 100644 --- a/Dockerfile +++ b/Dockerfile @@ -165,6 +165,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \ HEALTH_TARGET_ADDRESSES=cloudflare.com:443,github.com:443 \ HEALTH_ICMP_TARGET_IPS=1.1.1.1,8.8.8.8 \ + HEALTH_SMALL_CHECK_TYPE=icmp \ HEALTH_RESTART_VPN=on \ # DNS DNS_SERVER=on \ diff --git a/internal/configuration/settings/health.go b/internal/configuration/settings/health.go index e2ac010e..55756780 100644 --- a/internal/configuration/settings/health.go +++ b/internal/configuration/settings/health.go @@ -29,6 +29,10 @@ type Health struct { // although this can be less reliable. It defaults to [1.1.1.1,8.8.8.8], // and cannot be left empty in the internal state. ICMPTargetIPs []netip.Addr + // SmallCheckType is the type of small health check to perform. + // It can be "icmp" or "dns", and defaults to "icmp". + // Note it changes automatically to dns if icmp is not supported. + SmallCheckType string // RestartVPN indicates whether to restart the VPN connection // when the healthcheck fails. RestartVPN *bool @@ -37,6 +41,7 @@ type Health struct { var ( ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid") ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible") + ErrSmallCheckTypeNotValid = errors.New("small check type is not valid") ) func (h Health) Validate() (err error) { @@ -55,6 +60,11 @@ func (h Health) Validate() (err error) { } } + err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns") + if err != nil { + return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err) + } + return nil } @@ -63,6 +73,7 @@ func (h *Health) copy() (copied Health) { ServerAddress: h.ServerAddress, TargetAddresses: h.TargetAddresses, ICMPTargetIPs: gosettings.CopySlice(h.ICMPTargetIPs), + SmallCheckType: h.SmallCheckType, RestartVPN: gosettings.CopyPointer(h.RestartVPN), } } @@ -74,6 +85,7 @@ func (h *Health) OverrideWith(other Health) { h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress) h.TargetAddresses = gosettings.OverrideWithSlice(h.TargetAddresses, other.TargetAddresses) h.ICMPTargetIPs = gosettings.OverrideWithSlice(h.ICMPTargetIPs, other.ICMPTargetIPs) + h.SmallCheckType = gosettings.OverrideWithComparable(h.SmallCheckType, other.SmallCheckType) h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN) } @@ -84,6 +96,7 @@ func (h *Health) SetDefaults() { netip.AddrFrom4([4]byte{1, 1, 1, 1}), netip.AddrFrom4([4]byte{8, 8, 8, 8}), }) + h.SmallCheckType = gosettings.DefaultComparable(h.SmallCheckType, "icmp") h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true) } @@ -98,13 +111,19 @@ func (h Health) toLinesNode() (node *gotree.Node) { for _, targetAddr := range h.TargetAddresses { targetAddrs.Append(targetAddr) } - if len(h.ICMPTargetIPs) == 1 && h.ICMPTargetIPs[0].IsUnspecified() { - node.Appendf("ICMP target IP: VPN server IP address") - } else { - icmpIPs := node.Appendf("ICMP target IPs:") - for _, ip := range h.ICMPTargetIPs { - icmpIPs.Append(ip.String()) + switch h.SmallCheckType { + case "icmp": + icmpNode := node.Appendf("Small health check type: ICMP echo request") + if len(h.ICMPTargetIPs) == 1 && h.ICMPTargetIPs[0].IsUnspecified() { + icmpNode.Appendf("ICMP target IP: VPN server IP address") + } else { + icmpIPs := icmpNode.Appendf("ICMP target IPs:") + for _, ip := range h.ICMPTargetIPs { + icmpIPs.Append(ip.String()) + } } + case "dns": + node.Appendf("Small health check type: Plain DNS lookup over UDP") } node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN)) return node @@ -118,6 +137,7 @@ func (h *Health) Read(r *reader.Reader) (err error) { if err != nil { return err } + h.SmallCheckType = r.String("HEALTH_SMALL_CHECK_TYPE") h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN") if err != nil { return err diff --git a/internal/configuration/settings/settings_test.go b/internal/configuration/settings/settings_test.go index 454da841..4f051877 100644 --- a/internal/configuration/settings/settings_test.go +++ b/internal/configuration/settings/settings_test.go @@ -60,9 +60,10 @@ func Test_Settings_String(t *testing.T) { | ├── Target addresses: | | ├── cloudflare.com:443 | | └── github.com:443 -| ├── ICMP target IPs: -| | ├── 1.1.1.1 -| | └── 8.8.8.8 +| ├── Small health check type: ICMP echo request +| | └── ICMP target IPs: +| | ├── 1.1.1.1 +| | └── 8.8.8.8 | └── Restart VPN on healthcheck failure: yes ├── Shadowsocks server settings: | └── Enabled: no diff --git a/internal/healthcheck/checker.go b/internal/healthcheck/checker.go index 9404360b..f623d875 100644 --- a/internal/healthcheck/checker.go +++ b/internal/healthcheck/checker.go @@ -16,16 +16,16 @@ import ( ) type Checker struct { - tlsDialAddrs []string - dialer *net.Dialer - echoer *icmp.Echoer - dnsClient *dns.Client - logger Logger - icmpTargetIPs []netip.Addr - configMutex sync.Mutex + tlsDialAddrs []string + dialer *net.Dialer + echoer *icmp.Echoer + dnsClient *dns.Client + logger Logger + icmpTargetIPs []netip.Addr + smallCheckType string + configMutex sync.Mutex icmpNotPermitted bool - smallCheckName string // Internal periodic service signals stop context.CancelFunc @@ -45,14 +45,17 @@ func NewChecker(logger Logger) *Checker { } } -// SetConfig sets the TCP+TLS dial addresses and the ICMP echo IP address -// to target by the [Checker]. +// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address +// to target and the desired small check type (dns or icmp). // This function MUST be called before calling [Checker.Start]. -func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr) { +func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr, + smallCheckType string, +) { c.configMutex.Lock() defer c.configMutex.Unlock() c.tlsDialAddrs = tlsDialAddrs c.icmpTargetIPs = icmpTargets + c.smallCheckType = smallCheckType } // Start starts the checker by first running a blocking 6s-timed TCP+TLS check, @@ -63,10 +66,15 @@ func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr) { // It returns an error if the initial TCP+TLS check fails. // The Checker has to be ultimately stopped by calling [Checker.Stop]. func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) { - if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 { + if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" { panic("call Checker.SetConfig with non empty values before Checker.Start") } + if c.icmpNotPermitted { + // restore forced check type to dns if icmp was found to be not permitted + c.smallCheckType = smallCheckDNS + } + err = c.startupCheck(ctx) if err != nil { return nil, fmt.Errorf("startup check: %w", err) @@ -77,7 +85,6 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) c.stop = cancel done := make(chan struct{}) c.done = done - c.smallCheckName = "ICMP echo" const smallCheckPeriod = time.Minute smallCheckTimer := time.NewTimer(smallCheckPeriod) const fullCheckPeriod = 5 * time.Minute @@ -119,6 +126,7 @@ func (c *Checker) Stop() error { <-c.done c.tlsDialAddrs = nil c.icmpTargetIPs = nil + c.smallCheckType = "" return nil } @@ -140,20 +148,21 @@ func (c *Checker) smallPeriodicCheck(ctx context.Context) error { 30 * time.Second, } check := func(ctx context.Context, try int) error { - if c.icmpNotPermitted { + if c.smallCheckType == smallCheckDNS { return c.dnsClient.Check(ctx) } ip := icmpTargetIPs[try%len(icmpTargetIPs)] err := c.echoer.Echo(ctx, ip) if errors.Is(err, icmp.ErrNotPermitted) { c.icmpNotPermitted = true - c.smallCheckName = "plain DNS over UDP" - c.logger.Infof("%s; permanently falling back to %s checks.", c.smallCheckName, err) + c.smallCheckType = smallCheckDNS + c.logger.Infof("%s; permanently falling back to %s checks", + smallCheckTypeToString(c.smallCheckType), err) return c.dnsClient.Check(ctx) } return err } - return withRetries(ctx, tryTimeouts, c.logger, c.smallCheckName, check) + return withRetries(ctx, tryTimeouts, c.logger, smallCheckTypeToString(c.smallCheckType), check) } func (c *Checker) fullPeriodicCheck(ctx context.Context) error { @@ -299,3 +308,19 @@ func (c *Checker) startupCheck(ctx context.Context) error { } return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", ")) } + +const ( + smallCheckDNS = "dns" + smallCheckICMP = "icmp" +) + +func smallCheckTypeToString(smallCheckType string) string { + switch smallCheckType { + case smallCheckICMP: + return "ICMP echo" + case smallCheckDNS: + return "plain DNS over UDP" + default: + panic("unknown small check type: " + smallCheckType) + } +} diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index d783e5d8..ea7b1f9d 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -101,7 +101,7 @@ type CmdStarter interface { } type HealthChecker interface { - SetConfig(tlsDialAddrs []string, icmpTargetIPs []netip.Addr) + SetConfig(tlsDialAddrs []string, icmpTargetIPs []netip.Addr, smallCheckType string) Start(ctx context.Context) (runError <-chan error, err error) Stop() error } diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 16225fc0..7c031f89 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -35,7 +35,8 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() { icmpTargetIPs = []netip.Addr{data.serverIP} } - l.healthChecker.SetConfig(l.healthSettings.TargetAddresses, icmpTargetIPs) + l.healthChecker.SetConfig(l.healthSettings.TargetAddresses, icmpTargetIPs, + l.healthSettings.SmallCheckType) healthErrCh, err := l.healthChecker.Start(ctx) l.healthServer.SetError(err)