Use PMTUD to set the MTU to the VPN interface

- Add `VPN_PMTUD` option enabled by default
- One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
This commit is contained in:
Quentin McGaw
2025-09-10 14:43:21 +00:00
parent e21d798f57
commit 162d244865
12 changed files with 141 additions and 25 deletions

View File

@@ -81,6 +81,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) (err error)
}
type DNSLoop interface {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/provider"
)
@@ -14,39 +15,39 @@ import (
func setupOpenVPN(ctx context.Context, fw Firewall,
openvpnConf OpenVPN, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
canPortForward bool, err error,
logger openvpn.Logger) (runner *openvpn.Runner,
connection models.Connection, err error,
) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
}
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
}
if *settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
if err != nil {
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err)
}
}
if *settings.OpenVPN.KeyPassphrase != "" {
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
if err != nil {
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err)
}
}
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err)
}
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
return runner, connection.ServerName, connection.PortForward, nil
return runner, connection, nil
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/log"
)
@@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
var vpnRunner interface {
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
}
var serverName, vpnInterface string
var canPortForward bool
var vpnInterface string
var connection models.Connection
var err error
subLogger := l.logger.New(log.SetComponent(settings.Type))
if settings.Type == vpn.OpenVPN {
vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
}
if err != nil {
@@ -46,8 +47,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue
}
tunnelUpData := tunnelUpData{
serverName: serverName,
canPortForward: canPortForward,
PMTUD: *settings.PMTUD,
serverIP: connection.IP,
vpnType: settings.Type,
serverName: connection.ServerName,
canPortForward: connection.PortForward,
portForwarder: portForwarder,
vpnIntf: vpnInterface,
username: settings.Provider.PortForwarding.Username,

View File

@@ -2,15 +2,32 @@ package vpn
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
"github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
)
type tunnelUpData struct {
// Port forwarding
vpnIntf string
// vpnIntf is the name of the VPN network interface
// which is used both for port forwarding and MTU discovery
vpnIntf string
// Path MTU discovery fields:
// PMTUD indicates whether to perform Path MTU Discovery and
// adjust the VPN interface MTU accordingly.
PMTUD bool
// serverIP is used for path MTU discovery
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
// Port forwarding fields:
serverName string // used for PIA
canPortForward bool // used for PIA
username string // used for PIA
@@ -21,6 +38,16 @@ type tunnelUpData struct {
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.client.CloseIdleConnections()
if data.PMTUD {
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
mtuLogger.Info("finding maximum MTU, this takes around 3 seconds")
err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType,
l.netLinker, mtuLogger)
if err != nil {
l.logger.Error(err.Error())
}
}
for _, vpnPort := range l.vpnInputPorts {
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
if err != nil {
@@ -57,3 +84,50 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.logger.Error(err.Error())
}
}
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger,
) error {
link, err := netlinker.LinkByName(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err)
}
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger)
if err != nil {
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
logger.Infof("VPN interface %s MTU set to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/wireguard"
@@ -16,11 +17,11 @@ import (
func setupWireguard(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
}
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
@@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil {
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err)
}
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil {
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
}
return wireguarder, connection.ServerName, connection.PortForward, nil
return wireguarder, connection, nil
}