Use the VPN local gateway IP address to run path MTU discovery

This commit is contained in:
Quentin McGaw
2025-10-06 10:03:15 +00:00
parent f0f3193c1c
commit b9051b02bf
2 changed files with 12 additions and 10 deletions

View File

@@ -47,7 +47,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue continue
} }
tunnelUpData := tunnelUpData{ tunnelUpData := tunnelUpData{
serverIP: connection.IP,
vpnType: settings.Type, vpnType: settings.Type,
serverName: connection.ServerName, serverName: connection.ServerName,
canPortForward: connection.PortForward, canPortForward: connection.PortForward,

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"time" "time"
"github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/dns/v2/pkg/check"
@@ -18,8 +17,6 @@ type tunnelUpData struct {
// vpnIntf is the name of the VPN network interface // vpnIntf is the name of the VPN network interface
// which is used both for port forwarding and MTU discovery // which is used both for port forwarding and MTU discovery
vpnIntf string vpnIntf string
// serverIP is used for path MTU discovery
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead. // vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn". // It can be "wireguard" or "openvpn".
vpnType string vpnType string
@@ -35,11 +32,10 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.client.CloseIdleConnections() l.client.CloseIdleConnections()
mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
mtuLogger.Info("finding maximum MTU, this can take up to 4 seconds") err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType, l.netLinker, l.routing, mtuLogger)
l.netLinker, mtuLogger)
if err != nil { if err != nil {
l.logger.Error(err.Error()) mtuLogger.Error(err.Error())
} }
for _, vpnPort := range l.vpnInputPorts { for _, vpnPort := range l.vpnInputPorts {
@@ -82,8 +78,15 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
var errVPNTypeUnknown = errors.New("unknown VPN type") var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string, func updateToMaxMTU(ctx context.Context, vpnInterface string,
serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger, vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
) error { ) error {
logger.Info("finding maximum MTU, this can take up to 4 seconds")
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN gateway IP address: %w", err)
}
link, err := netlinker.LinkByName(vpnInterface) link, err := netlinker.LinkByName(vpnInterface)
if err != nil { if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err) return fmt.Errorf("getting VPN interface by name: %w", err)
@@ -114,7 +117,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string,
} }
const pingTimeout = time.Second const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger) vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
switch { switch {
case err == nil: case err == nil:
logger.Infof("Setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU) logger.Infof("Setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)