feat(healthcheck): combination of ICMP and TCP+TLS checks (#2923)
- New option: `HEALTH_ICMP_TARGET_IP` defaults to `0.0.0.0` meaning use the VPN server public IP address. - Options removed: `HEALTH_VPN_INITIAL_DURATION` and `HEALTH_VPN_ADDITIONAL_DURATION` - times and retries are handpicked and hardcoded. - Less aggressive checks and less false positive detection
This commit is contained in:
@@ -164,9 +164,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
# Health
|
||||
HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \
|
||||
HEALTH_TARGET_ADDRESS=cloudflare.com:443 \
|
||||
HEALTH_SUCCESS_WAIT_DURATION=5s \
|
||||
HEALTH_VPN_DURATION_INITIAL=6s \
|
||||
HEALTH_VPN_DURATION_ADDITION=5s \
|
||||
HEALTH_ICMP_TARGET_IP=0.0.0.0 \
|
||||
# DNS over TLS
|
||||
DOT=on \
|
||||
DOT_PROVIDERS=cloudflare \
|
||||
|
||||
@@ -414,6 +414,13 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return fmt.Errorf("starting public ip loop: %w", err)
|
||||
}
|
||||
|
||||
healthLogger := logger.New(log.SetComponent("healthcheck"))
|
||||
healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger)
|
||||
healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler(
|
||||
"HTTP health server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||
go healthcheckServer.Run(healthServerCtx, healthServerDone)
|
||||
healthChecker := healthcheck.NewChecker(healthLogger)
|
||||
|
||||
updaterLogger := logger.New(log.SetComponent("updater"))
|
||||
|
||||
unzipper := unzip.New(httpClient)
|
||||
@@ -424,8 +431,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
vpnLogger := logger.New(log.SetComponent("vpn"))
|
||||
vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6Supported, allSettings.Firewall.VPNInputPorts,
|
||||
providers, storage, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper,
|
||||
cmder, publicIPLooper, dnsLooper, vpnLogger, httpClient,
|
||||
providers, storage, allSettings.Health, healthChecker, healthcheckServer, ovpnConf, netLinker, firewallConf,
|
||||
routingConf, portForwardLooper, cmder, publicIPLooper, dnsLooper, vpnLogger, httpClient,
|
||||
buildInfo, *allSettings.Version.Enabled)
|
||||
vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler(
|
||||
"vpn", goroutine.OptionTimeout(time.Second))
|
||||
@@ -476,12 +483,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
<-httpServerReady
|
||||
controlGroupHandler.Add(httpServerHandler)
|
||||
|
||||
healthLogger := logger.New(log.SetComponent("healthcheck"))
|
||||
healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger, vpnLooper)
|
||||
healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler(
|
||||
"HTTP health server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||
go healthcheckServer.Run(healthServerCtx, healthServerDone)
|
||||
|
||||
orderHandler := goshutdown.NewOrderHandler("gluetun",
|
||||
order.OptionTimeout(totalShutdownTimeout),
|
||||
order.OptionOnSuccess(defaultShutdownOnSuccess),
|
||||
|
||||
@@ -12,6 +12,8 @@ func readObsolete(r *reader.Reader) (warnings []string) {
|
||||
"DOT_VERBOSITY": "DOT_VERBOSITY is obsolete, use LOG_LEVEL instead.",
|
||||
"DOT_VERBOSITY_DETAILS": "DOT_VERBOSITY_DETAILS is obsolete because it was specific to Unbound.",
|
||||
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
|
||||
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
|
||||
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
|
||||
}
|
||||
sortedKeys := maps.Keys(keyToMessage)
|
||||
slices.Sort(sortedKeys)
|
||||
|
||||
@@ -2,6 +2,7 @@ package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -24,16 +25,13 @@ type Health struct {
|
||||
// HTTP server. It defaults to 500 milliseconds.
|
||||
ReadTimeout time.Duration
|
||||
// TargetAddress is the address (host or host:port)
|
||||
// to TCP dial to periodically for the health check.
|
||||
// to TCP TLS dial to periodically for the health check.
|
||||
// It cannot be the empty string in the internal state.
|
||||
TargetAddress string
|
||||
// SuccessWait is the duration to wait to re-run the
|
||||
// healthcheck after a successful healthcheck.
|
||||
// It defaults to 5 seconds and cannot be zero in
|
||||
// the internal state.
|
||||
SuccessWait time.Duration
|
||||
// VPN has health settings specific to the VPN loop.
|
||||
VPN HealthyWait
|
||||
// ICMPTargetIP is the IP address to use for ICMP echo requests
|
||||
// in the health checker. It can be set to an unspecified address
|
||||
// such that the VPN server IP is used, which is also the default behavior.
|
||||
ICMPTargetIP netip.Addr
|
||||
}
|
||||
|
||||
func (h Health) Validate() (err error) {
|
||||
@@ -42,11 +40,6 @@ func (h Health) Validate() (err error) {
|
||||
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||
}
|
||||
|
||||
err = h.VPN.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("health VPN settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -56,8 +49,7 @@ func (h *Health) copy() (copied Health) {
|
||||
ReadHeaderTimeout: h.ReadHeaderTimeout,
|
||||
ReadTimeout: h.ReadTimeout,
|
||||
TargetAddress: h.TargetAddress,
|
||||
SuccessWait: h.SuccessWait,
|
||||
VPN: h.VPN.copy(),
|
||||
ICMPTargetIP: h.ICMPTargetIP,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,8 +61,7 @@ func (h *Health) OverrideWith(other Health) {
|
||||
h.ReadHeaderTimeout = gosettings.OverrideWithComparable(h.ReadHeaderTimeout, other.ReadHeaderTimeout)
|
||||
h.ReadTimeout = gosettings.OverrideWithComparable(h.ReadTimeout, other.ReadTimeout)
|
||||
h.TargetAddress = gosettings.OverrideWithComparable(h.TargetAddress, other.TargetAddress)
|
||||
h.SuccessWait = gosettings.OverrideWithComparable(h.SuccessWait, other.SuccessWait)
|
||||
h.VPN.overrideWith(other.VPN)
|
||||
h.ICMPTargetIP = gosettings.OverrideWithComparable(h.ICMPTargetIP, other.ICMPTargetIP)
|
||||
}
|
||||
|
||||
func (h *Health) SetDefaults() {
|
||||
@@ -80,9 +71,7 @@ func (h *Health) SetDefaults() {
|
||||
const defaultReadTimeout = 500 * time.Millisecond
|
||||
h.ReadTimeout = gosettings.DefaultComparable(h.ReadTimeout, defaultReadTimeout)
|
||||
h.TargetAddress = gosettings.DefaultComparable(h.TargetAddress, "cloudflare.com:443")
|
||||
const defaultSuccessWait = 5 * time.Second
|
||||
h.SuccessWait = gosettings.DefaultComparable(h.SuccessWait, defaultSuccessWait)
|
||||
h.VPN.setDefaults()
|
||||
h.ICMPTargetIP = gosettings.DefaultComparable(h.ICMPTargetIP, netip.IPv4Unspecified()) // use the VPN server IP
|
||||
}
|
||||
|
||||
func (h Health) String() string {
|
||||
@@ -93,10 +82,11 @@ func (h Health) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Health settings:")
|
||||
node.Appendf("Server listening address: %s", h.ServerAddress)
|
||||
node.Appendf("Target address: %s", h.TargetAddress)
|
||||
node.Appendf("Duration to wait after success: %s", h.SuccessWait)
|
||||
node.Appendf("Read header timeout: %s", h.ReadHeaderTimeout)
|
||||
node.Appendf("Read timeout: %s", h.ReadTimeout)
|
||||
node.AppendNode(h.VPN.toLinesNode("VPN"))
|
||||
icmpTarget := "VPN server IP"
|
||||
if !h.ICMPTargetIP.IsUnspecified() {
|
||||
icmpTarget = h.ICMPTargetIP.String()
|
||||
}
|
||||
node.Appendf("ICMP target IP: %s", icmpTarget)
|
||||
return node
|
||||
}
|
||||
|
||||
@@ -104,16 +94,9 @@ func (h *Health) Read(r *reader.Reader) (err error) {
|
||||
h.ServerAddress = r.String("HEALTH_SERVER_ADDRESS")
|
||||
h.TargetAddress = r.String("HEALTH_TARGET_ADDRESS",
|
||||
reader.RetroKeys("HEALTH_ADDRESS_TO_PING"))
|
||||
|
||||
h.SuccessWait, err = r.Duration("HEALTH_SUCCESS_WAIT_DURATION")
|
||||
h.ICMPTargetIP, err = r.NetipAddr("HEALTH_ICMP_TARGET_IP")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = h.VPN.read(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("VPN health settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
type HealthyWait struct {
|
||||
// Initial is the initial duration to wait for the program
|
||||
// to be healthy before taking action.
|
||||
// It cannot be nil in the internal state.
|
||||
Initial *time.Duration
|
||||
// Addition is the duration to add to the Initial duration
|
||||
// after Initial has expired to wait longer for the program
|
||||
// to be healthy.
|
||||
// It cannot be nil in the internal state.
|
||||
Addition *time.Duration
|
||||
}
|
||||
|
||||
func (h HealthyWait) validate() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HealthyWait) copy() (copied HealthyWait) {
|
||||
return HealthyWait{
|
||||
Initial: gosettings.CopyPointer(h.Initial),
|
||||
Addition: gosettings.CopyPointer(h.Addition),
|
||||
}
|
||||
}
|
||||
|
||||
// overrideWith overrides fields of the receiver
|
||||
// settings object with any field set in the other
|
||||
// settings.
|
||||
func (h *HealthyWait) overrideWith(other HealthyWait) {
|
||||
h.Initial = gosettings.OverrideWithPointer(h.Initial, other.Initial)
|
||||
h.Addition = gosettings.OverrideWithPointer(h.Addition, other.Addition)
|
||||
}
|
||||
|
||||
func (h *HealthyWait) setDefaults() {
|
||||
const initialDurationDefault = 6 * time.Second
|
||||
const additionDurationDefault = 5 * time.Second
|
||||
h.Initial = gosettings.DefaultPointer(h.Initial, initialDurationDefault)
|
||||
h.Addition = gosettings.DefaultPointer(h.Addition, additionDurationDefault)
|
||||
}
|
||||
|
||||
func (h HealthyWait) String() string {
|
||||
return h.toLinesNode("Health").String()
|
||||
}
|
||||
|
||||
func (h HealthyWait) toLinesNode(kind string) (node *gotree.Node) {
|
||||
node = gotree.New(kind + " wait durations:")
|
||||
node.Appendf("Initial duration: %s", *h.Initial)
|
||||
node.Appendf("Additional duration: %s", *h.Addition)
|
||||
return node
|
||||
}
|
||||
|
||||
func (h *HealthyWait) read(r *reader.Reader) (err error) {
|
||||
h.Initial, err = r.DurationPtr(
|
||||
"HEALTH_VPN_DURATION_INITIAL",
|
||||
reader.RetroKeys("HEALTH_OPENVPN_DURATION_INITIAL"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.Addition, err = r.DurationPtr(
|
||||
"HEALTH_VPN_DURATION_ADDITION",
|
||||
reader.RetroKeys("HEALTH_OPENVPN_DURATION_ADDITION"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -58,12 +58,7 @@ func Test_Settings_String(t *testing.T) {
|
||||
├── Health settings:
|
||||
| ├── Server listening address: 127.0.0.1:9999
|
||||
| ├── Target address: cloudflare.com:443
|
||||
| ├── Duration to wait after success: 5s
|
||||
| ├── Read header timeout: 100ms
|
||||
| ├── Read timeout: 500ms
|
||||
| └── VPN wait durations:
|
||||
| ├── Initial duration: 6s
|
||||
| └── Additional duration: 5s
|
||||
| └── ICMP target IP: VPN server IP
|
||||
├── Shadowsocks server settings:
|
||||
| └── Enabled: no
|
||||
├── HTTP proxy settings:
|
||||
|
||||
239
internal/healthcheck/checker.go
Normal file
239
internal/healthcheck/checker.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/healthcheck/dns"
|
||||
"github.com/qdm12/gluetun/internal/healthcheck/icmp"
|
||||
)
|
||||
|
||||
type Checker struct {
|
||||
tlsDialAddr string
|
||||
dialer *net.Dialer
|
||||
echoer *icmp.Echoer
|
||||
dnsClient *dns.Client
|
||||
logger Logger
|
||||
icmpTarget netip.Addr
|
||||
configMutex sync.Mutex
|
||||
|
||||
icmpNotPermitted bool
|
||||
|
||||
// Internal periodic service signals
|
||||
stop context.CancelFunc
|
||||
done <-chan struct{}
|
||||
}
|
||||
|
||||
func NewChecker(logger Logger) *Checker {
|
||||
return &Checker{
|
||||
dialer: &net.Dialer{
|
||||
Resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
},
|
||||
echoer: icmp.NewEchoer(logger),
|
||||
dnsClient: dns.New(),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the TCP+TLS dial address and the ICMP echo IP address
|
||||
// to target by the [Checker].
|
||||
// This function MUST be called before calling [Checker.Start].
|
||||
func (c *Checker) SetConfig(tlsDialAddr string, icmpTarget netip.Addr) {
|
||||
c.configMutex.Lock()
|
||||
defer c.configMutex.Unlock()
|
||||
c.tlsDialAddr = tlsDialAddr
|
||||
c.icmpTarget = icmpTarget
|
||||
}
|
||||
|
||||
// Start starts the checker by first running a blocking 2s-timed TCP+TLS check,
|
||||
// and, on success, starts the periodic checks in a separate goroutine:
|
||||
// - a "small" ICMP echo check every 15 seconds
|
||||
// - a "full" TCP+TLS check every 5 minutes
|
||||
// It returns a channel `runError` that receives an error if one of the periodic checks fail.
|
||||
// It returns an error if the initial TCP+TLS check fails.
|
||||
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
|
||||
if c.tlsDialAddr == "" || c.icmpTarget.IsUnspecified() {
|
||||
panic("call Checker.SetConfig with non empty values before Checker.Start")
|
||||
}
|
||||
|
||||
// connection isn't under load yet when the checker starts, so a short
|
||||
// 6 seconds timeout suffices and provides quick enough feedback that
|
||||
// the new connection is not working.
|
||||
const timeout = 6 * time.Second
|
||||
tcpTLSCheckCtx, tcpTLSCheckCancel := context.WithTimeout(ctx, timeout)
|
||||
err = tcpTLSCheck(tcpTLSCheckCtx, c.dialer, c.tlsDialAddr)
|
||||
tcpTLSCheckCancel()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("startup check: %w", err)
|
||||
}
|
||||
|
||||
ready := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c.stop = cancel
|
||||
done := make(chan struct{})
|
||||
c.done = done
|
||||
const smallCheckPeriod = 15 * time.Second
|
||||
smallCheckTimer := time.NewTimer(smallCheckPeriod)
|
||||
const fullCheckPeriod = 5 * time.Minute
|
||||
fullCheckTimer := time.NewTimer(fullCheckPeriod)
|
||||
runErrorCh := make(chan error)
|
||||
runError = runErrorCh
|
||||
go func() {
|
||||
defer close(done)
|
||||
close(ready)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fullCheckTimer.Stop()
|
||||
smallCheckTimer.Stop()
|
||||
return
|
||||
case <-smallCheckTimer.C:
|
||||
err := c.smallPeriodicCheck(ctx)
|
||||
if err != nil {
|
||||
runErrorCh <- fmt.Errorf("periodic small check: %w", err)
|
||||
return
|
||||
}
|
||||
smallCheckTimer.Reset(smallCheckPeriod)
|
||||
case <-fullCheckTimer.C:
|
||||
err := c.fullPeriodicCheck(ctx)
|
||||
if err != nil {
|
||||
runErrorCh <- fmt.Errorf("periodic full check: %w", err)
|
||||
return
|
||||
}
|
||||
fullCheckTimer.Reset(fullCheckPeriod)
|
||||
}
|
||||
}
|
||||
}()
|
||||
<-ready
|
||||
return runError, nil
|
||||
}
|
||||
|
||||
func (c *Checker) Stop() error {
|
||||
c.stop()
|
||||
<-c.done
|
||||
c.icmpTarget = netip.Addr{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
|
||||
c.configMutex.Lock()
|
||||
ip := c.icmpTarget
|
||||
c.configMutex.Unlock()
|
||||
const maxTries = 3
|
||||
const timeout = 3 * time.Second
|
||||
const extraTryTime = time.Second // 1s added for each subsequent retry
|
||||
check := func(ctx context.Context) error {
|
||||
if c.icmpNotPermitted {
|
||||
return c.dnsClient.Check(ctx)
|
||||
}
|
||||
err := c.echoer.Echo(ctx, ip)
|
||||
if errors.Is(err, icmp.ErrNotPermitted) {
|
||||
c.icmpNotPermitted = true
|
||||
c.logger.Warnf("%s; permanently falling back to plaintext DNS checks.", err)
|
||||
return c.dnsClient.Check(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, "ICMP echo", check)
|
||||
}
|
||||
|
||||
func (c *Checker) fullPeriodicCheck(ctx context.Context) error {
|
||||
const maxTries = 2
|
||||
// 10s timeout in case the connection is under stress
|
||||
// See https://github.com/qdm12/gluetun/issues/2270
|
||||
const timeout = 10 * time.Second
|
||||
const extraTryTime = 3 * time.Second // 3s added for each subsequent retry
|
||||
check := func(ctx context.Context) error {
|
||||
return tcpTLSCheck(ctx, c.dialer, c.tlsDialAddr)
|
||||
}
|
||||
return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, "TCP+TLS dial", check)
|
||||
}
|
||||
|
||||
func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error {
|
||||
// TODO use mullvad API if current provider is Mullvad
|
||||
|
||||
address, err := makeAddressToDial(targetAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const dialNetwork = "tcp4"
|
||||
connection, err := dialer.DialContext(ctx, dialNetwork, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing: %w", err)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(address, ":443") {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("splitting host and port: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: host,
|
||||
}
|
||||
tlsConnection := tls.Client(connection, tlsConfig)
|
||||
err = tlsConnection.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running TLS handshake: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = connection.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeAddressToDial(address string) (addressToDial string, err error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
addrErr := new(net.AddrError)
|
||||
ok := errors.As(err, &addrErr)
|
||||
if !ok || addrErr.Err != "missing port in address" {
|
||||
return "", fmt.Errorf("splitting host and port from address: %w", err)
|
||||
}
|
||||
host = address
|
||||
const defaultPort = "443"
|
||||
port = defaultPort
|
||||
}
|
||||
address = net.JoinHostPort(host, port)
|
||||
return address, nil
|
||||
}
|
||||
|
||||
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
|
||||
|
||||
func withRetries(ctx context.Context, maxTries uint, tryTimeout, extraTryTime time.Duration,
|
||||
warner Logger, checkName string, check func(ctx context.Context) error,
|
||||
) error {
|
||||
try := uint(0)
|
||||
for {
|
||||
timeout := tryTimeout + time.Duration(try)*extraTryTime //nolint:gosec
|
||||
checkCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
err := check(checkCtx)
|
||||
cancel()
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case ctx.Err() != nil:
|
||||
return fmt.Errorf("%s context error: %w", checkName, ctx.Err())
|
||||
default:
|
||||
warner.Warnf("%s attempt %d/%d failed: %v", checkName, try+1, maxTries, err)
|
||||
try++
|
||||
if try == maxTries {
|
||||
return fmt.Errorf("%w: %s: after %d attempts", ErrAllCheckTriesFailed, checkName, maxTries)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Server_healthCheck(t *testing.T) {
|
||||
func Test_Checker_fullcheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("canceled real dialer", func(t *testing.T) {
|
||||
@@ -21,20 +20,18 @@ func Test_Server_healthCheck(t *testing.T) {
|
||||
dialer := &net.Dialer{}
|
||||
const address = "cloudflare.com:443"
|
||||
|
||||
server := &Server{
|
||||
checker := &Checker{
|
||||
dialer: dialer,
|
||||
config: settings.Health{
|
||||
TargetAddress: address,
|
||||
},
|
||||
tlsDialAddr: address,
|
||||
}
|
||||
|
||||
canceledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := server.healthCheck(canceledCtx)
|
||||
err := checker.fullPeriodicCheck(canceledCtx)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "operation was canceled")
|
||||
assert.EqualError(t, err, "TCP+TLS dial context error: context canceled")
|
||||
})
|
||||
|
||||
t.Run("dial localhost:0", func(t *testing.T) {
|
||||
@@ -54,14 +51,12 @@ func Test_Server_healthCheck(t *testing.T) {
|
||||
listeningAddress := listener.Addr()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
server := &Server{
|
||||
checker := &Checker{
|
||||
dialer: dialer,
|
||||
config: settings.Health{
|
||||
TargetAddress: listeningAddress.String(),
|
||||
},
|
||||
tlsDialAddr: listeningAddress.String(),
|
||||
}
|
||||
|
||||
err = server.healthCheck(ctx)
|
||||
err = checker.fullPeriodicCheck(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
39
internal/healthcheck/dns/dns.go
Normal file
39
internal/healthcheck/dns/dns.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Client is a simple plaintext UDP DNS client, to be used for healthchecks.
|
||||
// Note the client connects to a DNS server only over UDP on port 53,
|
||||
// because we don't want to use DoT or DoH and impact the TCP connections
|
||||
// when running a healthcheck.
|
||||
type Client struct{}
|
||||
|
||||
func New() *Client {
|
||||
return &Client{}
|
||||
}
|
||||
|
||||
var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup")
|
||||
|
||||
func (c *Client) Check(ctx context.Context) error {
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
dialer := net.Dialer{}
|
||||
return dialer.DialContext(ctx, "udp", "1.1.1.1:53")
|
||||
},
|
||||
}
|
||||
ips, err := resolver.LookupIP(ctx, "ip", "github.com")
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case len(ips) == 0:
|
||||
return fmt.Errorf("%w", ErrLookupNoIPs)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -9,13 +9,15 @@ import (
|
||||
type handler struct {
|
||||
healthErr error
|
||||
healthErrMu sync.RWMutex
|
||||
logger Logger
|
||||
}
|
||||
|
||||
var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet")
|
||||
|
||||
func newHandler() *handler {
|
||||
func newHandler(logger Logger) *handler {
|
||||
return &handler{
|
||||
healthErr: errHealthcheckNotRunYet,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) runHealthcheckLoop(ctx context.Context, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
|
||||
timeoutIndex := 0
|
||||
healthcheckTimeouts := []time.Duration{
|
||||
2 * time.Second,
|
||||
4 * time.Second,
|
||||
6 * time.Second,
|
||||
8 * time.Second,
|
||||
// This can be useful when the connection is under stress
|
||||
// See https://github.com/qdm12/gluetun/issues/2270
|
||||
10 * time.Second,
|
||||
}
|
||||
s.vpn.healthyTimer = time.NewTimer(s.vpn.healthyWait)
|
||||
|
||||
for {
|
||||
previousErr := s.handler.getErr()
|
||||
|
||||
timeout := healthcheckTimeouts[timeoutIndex]
|
||||
healthcheckCtx, healthcheckCancel := context.WithTimeout(
|
||||
ctx, timeout)
|
||||
err := s.healthCheck(healthcheckCtx)
|
||||
healthcheckCancel()
|
||||
|
||||
s.handler.setErr(err)
|
||||
|
||||
switch {
|
||||
case previousErr != nil && err == nil: // First success
|
||||
s.logger.Info("healthy!")
|
||||
timeoutIndex = 0
|
||||
s.vpn.healthyTimer.Stop()
|
||||
s.vpn.healthyWait = *s.config.VPN.Initial
|
||||
case previousErr == nil && err != nil: // First failure
|
||||
s.logger.Debug("unhealthy: " + err.Error())
|
||||
s.vpn.healthyTimer.Stop()
|
||||
s.vpn.healthyTimer = time.NewTimer(s.vpn.healthyWait)
|
||||
case previousErr != nil && err != nil: // Nth failure
|
||||
if timeoutIndex < len(healthcheckTimeouts)-1 {
|
||||
timeoutIndex++
|
||||
}
|
||||
select {
|
||||
case <-s.vpn.healthyTimer.C:
|
||||
timeoutIndex = 0 // retry next with the smallest timeout
|
||||
s.onUnhealthyVPN(ctx, err.Error())
|
||||
default:
|
||||
}
|
||||
case previousErr == nil && err == nil: // Nth success
|
||||
timer := time.NewTimer(s.config.SuccessWait)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) healthCheck(ctx context.Context) (err error) {
|
||||
// TODO use mullvad API if current provider is Mullvad
|
||||
|
||||
address, err := makeAddressToDial(s.config.TargetAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const dialNetwork = "tcp4"
|
||||
connection, err := s.dialer.DialContext(ctx, dialNetwork, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing: %w", err)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(address, ":443") {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("splitting host and port: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: host,
|
||||
}
|
||||
tlsConnection := tls.Client(connection, tlsConfig)
|
||||
err = tlsConnection.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running TLS handshake: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = connection.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeAddressToDial(address string) (addressToDial string, err error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
addrErr := new(net.AddrError)
|
||||
ok := errors.As(err, &addrErr)
|
||||
if !ok || addrErr.Err != "missing port in address" {
|
||||
return "", fmt.Errorf("splitting host and port from address: %w", err)
|
||||
}
|
||||
host = address
|
||||
const defaultPort = "443"
|
||||
port = defaultPort
|
||||
}
|
||||
address = net.JoinHostPort(host, port)
|
||||
return address, nil
|
||||
}
|
||||
49
internal/healthcheck/icmp/apple_ipv4.go
Normal file
49
internal/healthcheck/icmp/apple_ipv4.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
var _ net.PacketConn = &ipv4Wrapper{}
|
||||
|
||||
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
|
||||
// the net.PacketConn interface. It's only used for Darwin or iOS.
|
||||
type ipv4Wrapper struct {
|
||||
ipv4Conn *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
|
||||
return &ipv4Wrapper{ipv4Conn: ipv4}
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return i.ipv4Conn.WriteTo(p, nil, addr)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) Close() error {
|
||||
return i.ipv4Conn.Close()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) LocalAddr() net.Addr {
|
||||
return i.ipv4Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetWriteDeadline(t)
|
||||
}
|
||||
190
internal/healthcheck/icmp/echo.go
Normal file
190
internal/healthcheck/icmp/echo.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
)
|
||||
|
||||
type Echoer struct {
|
||||
buffer []byte
|
||||
randomSource io.Reader
|
||||
logger Logger
|
||||
}
|
||||
|
||||
func NewEchoer(logger Logger) *Echoer {
|
||||
const maxICMPEchoSize = 1500
|
||||
buffer := make([]byte, maxICMPEchoSize)
|
||||
var seed [32]byte
|
||||
_, _ = cryptorand.Read(seed[:])
|
||||
randomSource := rand.NewChaCha8(seed)
|
||||
return &Echoer{
|
||||
buffer: buffer,
|
||||
randomSource: randomSource,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTimedOut = errors.New("timed out waiting for ICMP echo reply")
|
||||
ErrNotPermitted = errors.New("not permitted")
|
||||
)
|
||||
|
||||
func (i *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted)
|
||||
}
|
||||
return fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
const echoDataSize = 32
|
||||
id, message := buildMessageToSend(ipVersion, echoDataSize, i.randomSource)
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||
}
|
||||
return fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
receivedData, err := receiveEchoReply(conn, id, i.buffer, ipVersion, i.logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
||||
return fmt.Errorf("%w", ErrTimedOut)
|
||||
}
|
||||
return fmt.Errorf("receiving ICMP echo reply: %w", err)
|
||||
}
|
||||
|
||||
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
||||
if !bytes.Equal(receivedData, sentData) {
|
||||
return fmt.Errorf("%w: sent %x and received %x", ErrICMPEchoDataMismatch, sentData, receivedData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildMessageToSend(ipVersion string, size uint, randomSource io.Reader) (id int, message *icmp.Message) {
|
||||
const uint16Bytes = 2
|
||||
idBytes := make([]byte, uint16Bytes)
|
||||
_, _ = randomSource.Read(idBytes)
|
||||
id = int(binary.BigEndian.Uint16(idBytes))
|
||||
|
||||
var icmpType icmp.Type
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpType = ipv4.ICMPTypeEcho
|
||||
case "v6":
|
||||
icmpType = ipv6.ICMPTypeEchoRequest
|
||||
default:
|
||||
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
|
||||
}
|
||||
messageBodyData := make([]byte, size)
|
||||
_, _ = randomSource.Read(messageBodyData)
|
||||
|
||||
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
|
||||
message = &icmp.Message{
|
||||
Type: icmpType, // echo request
|
||||
Code: 0, // no code
|
||||
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
|
||||
Body: &icmp.Echo{
|
||||
ID: id,
|
||||
Seq: 0, // only one packet
|
||||
Data: messageBodyData,
|
||||
},
|
||||
}
|
||||
return id, message
|
||||
}
|
||||
|
||||
func receiveEchoReply(conn net.PacketConn, id int, buffer []byte, ipVersion string, logger Logger,
|
||||
) (data []byte, err error) {
|
||||
var icmpProtocol int
|
||||
const (
|
||||
icmpv4Protocol = 1
|
||||
icmpv6Protocol = 58
|
||||
)
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
for {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
// Parse the ICMP message
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch body := message.Body.(type) {
|
||||
case *icmp.Echo:
|
||||
if id != body.ID {
|
||||
logger.Warnf("ignoring ICMP echo reply mismatching expected id %d (id: %d, type: %d, code: %d, length: %d)",
|
||||
id, body.ID, message.Type, message.Code, len(packetBytes))
|
||||
continue // not the ID we are looking for
|
||||
}
|
||||
return body.Data, nil
|
||||
case *icmp.DstUnreach:
|
||||
logger.Debugf("ignoring ICMP destination unreachable message (type: 3, code: %d, expected-id %d)", message.Code, id)
|
||||
// See https://github.com/qdm12/gluetun/pull/2923#issuecomment-3377532249
|
||||
// on why we ignore this message. If it is actually unreachable, the timeout on waiting for
|
||||
// the echo reply will do instead of returning an error error.
|
||||
continue
|
||||
case *icmp.TimeExceeded:
|
||||
logger.Debugf("ignoring ICMP time exceeded message (type: 11, code: %d, expected-id %d)", message.Code, id)
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %T (type %d, code %d, expected-id %d)",
|
||||
ErrICMPBodyUnsupported, body, message.Type, message.Code, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
6
internal/healthcheck/icmp/interfaces.go
Normal file
6
internal/healthcheck/icmp/interfaces.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package icmp
|
||||
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
}
|
||||
35
internal/healthcheck/icmp/listen.go
Normal file
35
internal/healthcheck/icmp/listen.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package icmp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listening for ICMPv6 packets: %w", err)
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
package healthcheck
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Debugf(format string, args ...any)
|
||||
Info(s string)
|
||||
Warnf(format string, args ...any)
|
||||
Error(s string)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
type vpnHealth struct {
|
||||
loop StatusApplier
|
||||
healthyWait time.Duration
|
||||
healthyTimer *time.Timer
|
||||
}
|
||||
|
||||
func (s *Server) onUnhealthyVPN(ctx context.Context, lastErrMessage string) {
|
||||
s.logger.Info("program has been unhealthy for " +
|
||||
s.vpn.healthyWait.String() + ": restarting VPN (healthcheck error: " + lastErrMessage + ")")
|
||||
s.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
||||
s.logger.Info("DO NOT OPEN AN ISSUE UNLESS YOU HAVE READ AND TRIED EVERY POSSIBLE SOLUTION")
|
||||
_, _ = s.vpn.loop.ApplyStatus(ctx, constants.Stopped)
|
||||
_, _ = s.vpn.loop.ApplyStatus(ctx, constants.Running)
|
||||
s.vpn.healthyWait += *s.config.VPN.Addition
|
||||
s.vpn.healthyTimer = time.NewTimer(s.vpn.healthyWait)
|
||||
}
|
||||
@@ -10,9 +10,6 @@ import (
|
||||
func (s *Server) Run(ctx context.Context, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
go s.runHealthcheckLoop(ctx, loopDone)
|
||||
|
||||
server := http.Server{
|
||||
Addr: s.config.ServerAddress,
|
||||
Handler: s.handler,
|
||||
@@ -37,6 +34,5 @@ func (s *Server) Run(ctx context.Context, done chan<- struct{}) {
|
||||
s.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
<-loopDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -11,30 +10,21 @@ import (
|
||||
type Server struct {
|
||||
logger Logger
|
||||
handler *handler
|
||||
dialer *net.Dialer
|
||||
config settings.Health
|
||||
vpn vpnHealth
|
||||
}
|
||||
|
||||
func NewServer(config settings.Health,
|
||||
logger Logger, vpnLoop StatusApplier,
|
||||
) *Server {
|
||||
func NewServer(config settings.Health, logger Logger) *Server {
|
||||
return &Server{
|
||||
logger: logger,
|
||||
handler: newHandler(),
|
||||
dialer: &net.Dialer{
|
||||
Resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
},
|
||||
handler: newHandler(logger),
|
||||
config: config,
|
||||
vpn: vpnHealth{
|
||||
loop: vpnLoop,
|
||||
healthyWait: *config.VPN.Initial,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) SetError(err error) {
|
||||
s.handler.setErr(err)
|
||||
}
|
||||
|
||||
type StatusApplier interface {
|
||||
ApplyStatus(ctx context.Context, status models.LoopStatus) (
|
||||
outcome string, err error)
|
||||
|
||||
@@ -99,3 +99,13 @@ type CmdStarter interface {
|
||||
stdoutLines, stderrLines <-chan string,
|
||||
waitError <-chan error, startErr error)
|
||||
}
|
||||
|
||||
type HealthChecker interface {
|
||||
SetConfig(tlsDialAddr string, icmpTarget netip.Addr)
|
||||
Start(ctx context.Context) (runError <-chan error, err error)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type HealthServer interface {
|
||||
SetError(err error)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@ type Loop struct {
|
||||
state *state.State
|
||||
providers Providers
|
||||
storage Storage
|
||||
healthSettings settings.Health
|
||||
healthChecker HealthChecker
|
||||
healthServer HealthServer
|
||||
// Fixed parameters
|
||||
buildInfo models.BuildInformation
|
||||
versionInfo bool
|
||||
@@ -49,7 +52,8 @@ const (
|
||||
)
|
||||
|
||||
func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint16,
|
||||
providers Providers, storage Storage, openvpnConf OpenVPN,
|
||||
providers Providers, storage Storage, healthSettings settings.Health,
|
||||
healthChecker HealthChecker, healthServer HealthServer, openvpnConf OpenVPN,
|
||||
netLinker NetLinker, fw Firewall, routing Routing,
|
||||
portForward PortForward, starter CmdStarter,
|
||||
publicip PublicIPLoop, dnsLooper DNSLoop,
|
||||
@@ -69,6 +73,9 @@ func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint1
|
||||
state: state,
|
||||
providers: providers,
|
||||
storage: storage,
|
||||
healthSettings: healthSettings,
|
||||
healthChecker: healthChecker,
|
||||
healthServer: healthServer,
|
||||
buildInfo: buildInfo,
|
||||
versionInfo: versionInfo,
|
||||
ipv6Supported: ipv6Supported,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
@@ -14,39 +15,38 @@ import (
|
||||
func setupOpenVPN(ctx context.Context, fw Firewall,
|
||||
openvpnConf OpenVPN, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
|
||||
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
|
||||
canPortForward bool, err error,
|
||||
logger openvpn.Logger) (runner *openvpn.Runner, connection models.Connection, err error,
|
||||
) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
|
||||
}
|
||||
|
||||
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
|
||||
|
||||
if err := openvpnConf.WriteConfig(lines); err != nil {
|
||||
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
|
||||
}
|
||||
|
||||
if *settings.OpenVPN.User != "" {
|
||||
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if *settings.OpenVPN.KeyPassphrase != "" {
|
||||
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
||||
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||
}
|
||||
|
||||
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
||||
|
||||
return runner, connection.ServerName, connection.PortForward, nil
|
||||
return runner, connection, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
@@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
var vpnRunner interface {
|
||||
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
|
||||
}
|
||||
var serverName, vpnInterface string
|
||||
var canPortForward bool
|
||||
var vpnInterface string
|
||||
var connection models.Connection
|
||||
var err error
|
||||
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
||||
if settings.Type == vpn.OpenVPN {
|
||||
vpnInterface = settings.OpenVPN.Interface
|
||||
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
|
||||
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
|
||||
} else { // Wireguard
|
||||
vpnInterface = settings.Wireguard.Interface
|
||||
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
providerConf, settings, l.ipv6Supported, subLogger)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -46,8 +47,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
serverName: serverName,
|
||||
canPortForward: canPortForward,
|
||||
serverIP: connection.IP,
|
||||
serverName: connection.ServerName,
|
||||
canPortForward: connection.PortForward,
|
||||
portForwarder: portForwarder,
|
||||
vpnIntf: vpnInterface,
|
||||
username: settings.Provider.PortForwarding.Username,
|
||||
@@ -73,7 +75,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
for stayHere {
|
||||
select {
|
||||
case <-tunnelReady:
|
||||
go l.onTunnelUp(openvpnCtx, tunnelUpData)
|
||||
go l.onTunnelUp(openvpnCtx, ctx, tunnelUpData)
|
||||
case <-ctx.Done():
|
||||
l.cleanup()
|
||||
openvpnCancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
)
|
||||
|
||||
type tunnelUpData struct {
|
||||
// Healthcheck
|
||||
serverIP netip.Addr
|
||||
// Port forwarding
|
||||
vpnIntf string
|
||||
serverName string // used for PIA
|
||||
@@ -18,7 +21,7 @@ type tunnelUpData struct {
|
||||
portForwarder PortForwarder
|
||||
}
|
||||
|
||||
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||
l.client.CloseIdleConnections()
|
||||
|
||||
for _, vpnPort := range l.vpnInputPorts {
|
||||
@@ -28,6 +31,24 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
icmpTarget := l.healthSettings.ICMPTargetIP
|
||||
if icmpTarget.IsUnspecified() {
|
||||
icmpTarget = data.serverIP
|
||||
}
|
||||
l.healthChecker.SetConfig(l.healthSettings.TargetAddress, icmpTarget)
|
||||
|
||||
healthErrCh, err := l.healthChecker.Start(ctx)
|
||||
l.healthServer.SetError(err)
|
||||
if err != nil {
|
||||
// Note this restart call must be done in a separate goroutine
|
||||
// from the VPN loop goroutine.
|
||||
l.restartVPN(loopCtx, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = l.healthChecker.Stop()
|
||||
}()
|
||||
|
||||
if *l.dnsLooper.GetSettings().DoT.Enabled {
|
||||
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
|
||||
} else {
|
||||
@@ -37,7 +58,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
}
|
||||
}
|
||||
|
||||
err := l.publicip.RunOnce(ctx)
|
||||
err = l.publicip.RunOnce(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("getting public IP address information: " + err.Error())
|
||||
}
|
||||
@@ -56,4 +77,21 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case healthErr := <-healthErrCh:
|
||||
l.healthServer.SetError(healthErr)
|
||||
// Note this restart call must be done in a separate goroutine
|
||||
// from the VPN loop goroutine.
|
||||
l.restartVPN(loopCtx, healthErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
||||
l.logger.Warnf("restarting VPN because it failed to pass the healthcheck: %s", healthErr)
|
||||
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
||||
l.logger.Info("DO NOT OPEN AN ISSUE UNLESS YOU HAVE READ AND TRIED EVERY POSSIBLE SOLUTION")
|
||||
_, _ = l.ApplyStatus(ctx, constants.Stopped)
|
||||
_, _ = l.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
@@ -16,11 +17,11 @@ import (
|
||||
func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
fw Firewall, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
|
||||
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
|
||||
) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
|
||||
}
|
||||
|
||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
|
||||
@@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
|
||||
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err)
|
||||
}
|
||||
|
||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
|
||||
}
|
||||
|
||||
return wireguarder, connection.ServerName, connection.PortForward, nil
|
||||
return wireguarder, connection, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user