From 162d24486576693aeb819b3636dc464f033cb66c Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 10 Sep 2025 14:43:21 +0000 Subject: [PATCH] Use PMTUD to set the MTU to the VPN interface - Add `VPN_PMTUD` option enabled by default - One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism --- Dockerfile | 1 + cmd/gluetun/main.go | 1 + internal/configuration/settings/vpn.go | 11 ++++ internal/netlink/link.go | 4 ++ internal/pmtud/errors.go | 14 +++++ internal/pmtud/ipv4.go | 4 +- internal/pmtud/ipv6.go | 4 +- internal/vpn/interfaces.go | 1 + internal/vpn/openvpn.go | 19 ++++--- internal/vpn/run.go | 16 ++++-- internal/vpn/tunnelup.go | 78 +++++++++++++++++++++++++- internal/vpn/wireguard.go | 13 +++-- 12 files changed, 141 insertions(+), 25 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5e2fa98a..b043ea9b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -77,6 +77,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ VPN_TYPE=openvpn \ # Common VPN options VPN_INTERFACE=tun0 \ + VPN_PMTUD=on \ # OpenVPN OPENVPN_ENDPOINT_IP= \ OPENVPN_ENDPOINT_PORT= \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c5b47126..65e9dbb9 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -580,6 +580,7 @@ type Linker interface { LinkDel(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) + LinkSetMTU(link netlink.Link, mtu int) error } type clier interface { diff --git a/internal/configuration/settings/vpn.go b/internal/configuration/settings/vpn.go index aec51543..2472e76b 100644 --- a/internal/configuration/settings/vpn.go +++ b/internal/configuration/settings/vpn.go @@ -18,6 +18,7 @@ type VPN struct { Provider Provider `json:"provider"` OpenVPN OpenVPN `json:"openvpn"` Wireguard Wireguard `json:"wireguard"` + PMTUD *bool `json:"pmtud"` } // TODO v4 remove pointer for receiver (because of Surfshark). @@ -54,6 +55,7 @@ func (v *VPN) Copy() (copied VPN) { Provider: v.Provider.copy(), OpenVPN: v.OpenVPN.copy(), Wireguard: v.Wireguard.copy(), + PMTUD: gosettings.CopyPointer(v.PMTUD), } } @@ -62,6 +64,7 @@ func (v *VPN) OverrideWith(other VPN) { v.Provider.overrideWith(other.Provider) v.OpenVPN.overrideWith(other.OpenVPN) v.Wireguard.overrideWith(other.Wireguard) + v.PMTUD = gosettings.OverrideWithPointer(v.PMTUD, other.PMTUD) } func (v *VPN) setDefaults() { @@ -69,6 +72,7 @@ func (v *VPN) setDefaults() { v.Provider.setDefaults() v.OpenVPN.setDefaults(v.Provider.Name) v.Wireguard.setDefaults(v.Provider.Name) + v.PMTUD = gosettings.DefaultPointer(v.PMTUD, true) } func (v VPN) String() string { @@ -86,6 +90,8 @@ func (v VPN) toLinesNode() (node *gotree.Node) { node.AppendNode(v.Wireguard.toLinesNode()) } + node.Appendf("Path MTU discovery update: %s", gosettings.BoolToYesNo(v.PMTUD)) + return node } @@ -107,5 +113,10 @@ func (v *VPN) read(r *reader.Reader) (err error) { return fmt.Errorf("wireguard: %w", err) } + v.PMTUD, err = r.BoolPtr("VPN_PMTUD") + if err != nil { + return err + } + return nil } diff --git a/internal/netlink/link.go b/internal/netlink/link.go index d810e47e..b2c96134 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) { return netlink.LinkSetDown(linkToNetlinkLink(&link)) } +func (n *NetLink) LinkSetMTU(link Link, mtu int) error { + return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu) +} + type netlinkLinkImpl struct { attrs *netlink.LinkAttrs linkType string diff --git a/internal/pmtud/errors.go b/internal/pmtud/errors.go index 095deb86..1e81f1ae 100644 --- a/internal/pmtud/errors.go +++ b/internal/pmtud/errors.go @@ -1,10 +1,24 @@ package pmtud import ( + "context" "errors" + "fmt" + "net" + "time" ) var ( ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable") ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported") ) + +func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive + switch { + case errors.Is(timedCtx.Err(), context.DeadlineExceeded): + err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout) + case timedCtx.Err() != nil: + err = timedCtx.Err() + } + return err +} diff --git a/internal/pmtud/ipv4.go b/internal/pmtud/ipv4.go index 9b080e0b..c75c6e3d 100644 --- a/internal/pmtud/ipv4.go +++ b/internal/pmtud/ipv4.go @@ -73,6 +73,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) return 0, fmt.Errorf("writing ICMP message: %w", err) } @@ -84,6 +85,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J bytesRead, _, err := conn.ReadFrom(buffer) if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) return 0, fmt.Errorf("reading from ICMP connection: %w", err) } packetBytes := buffer[:bytesRead] @@ -135,7 +137,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr, if inboundID == outboundID { return physicalLinkMTU, nil } - logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", + logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d", inboundID, outboundID) continue default: diff --git a/internal/pmtud/ipv6.go b/internal/pmtud/ipv6.go index a9cc196e..b73ea09f 100644 --- a/internal/pmtud/ipv6.go +++ b/internal/pmtud/ipv6.go @@ -53,6 +53,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()}) if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) return 0, fmt.Errorf("writing ICMP message: %w", err) } @@ -64,6 +65,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J bytesRead, _, err := conn.ReadFrom(buffer) if err != nil { + err = wrapConnErr(err, ctx, pingTimeout) return 0, fmt.Errorf("reading from ICMP connection: %w", err) } packetBytes := buffer[:bytesRead] @@ -106,7 +108,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr, if inboundID == outboundID { return physicalLinkMTU, nil } - logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", + logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d", inboundID, outboundID) continue default: diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 68103690..fa075bbd 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -81,6 +81,7 @@ type Linker interface { LinkDel(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetDown(link netlink.Link) (err error) + LinkSetMTU(link netlink.Link, mtu int) (err error) } type DNSLoop interface { diff --git a/internal/vpn/openvpn.go b/internal/vpn/openvpn.go index 102640e1..c9842f43 100644 --- a/internal/vpn/openvpn.go +++ b/internal/vpn/openvpn.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/provider" ) @@ -14,39 +15,39 @@ import ( func setupOpenVPN(ctx context.Context, fw Firewall, openvpnConf OpenVPN, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, starter CmdStarter, - logger openvpn.Logger) (runner *openvpn.Runner, serverName string, - canPortForward bool, err error, + logger openvpn.Logger) (runner *openvpn.Runner, + connection models.Connection, err error, ) { - connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) + connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) if err != nil { - return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err) + return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err) } lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported) if err := openvpnConf.WriteConfig(lines); err != nil { - return nil, "", false, fmt.Errorf("writing configuration to file: %w", err) + return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err) } if *settings.OpenVPN.User != "" { err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password) if err != nil { - return nil, "", false, fmt.Errorf("writing auth to file: %w", err) + return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err) } } if *settings.OpenVPN.KeyPassphrase != "" { err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase) if err != nil { - return nil, "", false, fmt.Errorf("writing askpass file: %w", err) + return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err) } } if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil { - return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err) + return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err) } runner = openvpn.NewRunner(settings.OpenVPN, starter, logger) - return runner, connection.ServerName, connection.PortForward, nil + return runner, connection, nil } diff --git a/internal/vpn/run.go b/internal/vpn/run.go index a0cc0274..1788c79f 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -5,6 +5,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/vpn" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/log" ) @@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { var vpnRunner interface { Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{}) } - var serverName, vpnInterface string - var canPortForward bool + var vpnInterface string + var connection models.Connection var err error subLogger := l.logger.New(log.SetComponent(settings.Type)) if settings.Type == vpn.OpenVPN { vpnInterface = settings.OpenVPN.Interface - vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw, + vpnRunner, connection, err = setupOpenVPN(ctx, l.fw, l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) } else { // Wireguard vpnInterface = settings.Wireguard.Interface - vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw, + vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.ipv6Supported, subLogger) } if err != nil { @@ -46,8 +47,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ - serverName: serverName, - canPortForward: canPortForward, + PMTUD: *settings.PMTUD, + serverIP: connection.IP, + vpnType: settings.Type, + serverName: connection.ServerName, + canPortForward: connection.PortForward, portForwarder: portForwarder, vpnIntf: vpnInterface, username: settings.Provider.PortForwarding.Username, diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 103d65dd..473766a3 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -2,15 +2,32 @@ package vpn import ( "context" + "errors" + "fmt" + "net/netip" + "time" "github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/pmtud" "github.com/qdm12/gluetun/internal/version" + "github.com/qdm12/log" ) type tunnelUpData struct { - // Port forwarding - vpnIntf string + // vpnIntf is the name of the VPN network interface + // which is used both for port forwarding and MTU discovery + vpnIntf string + // Path MTU discovery fields: + // PMTUD indicates whether to perform Path MTU Discovery and + // adjust the VPN interface MTU accordingly. + PMTUD bool + // serverIP is used for path MTU discovery + serverIP netip.Addr + // vpnType is used for path MTU discovery to find the protocol overhead. + // It can be "wireguard" or "openvpn". + vpnType string + // Port forwarding fields: serverName string // used for PIA canPortForward bool // used for PIA username string // used for PIA @@ -21,6 +38,16 @@ type tunnelUpData struct { func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { l.client.CloseIdleConnections() + if data.PMTUD { + mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) + mtuLogger.Info("finding maximum MTU, this takes around 3 seconds") + err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType, + l.netLinker, mtuLogger) + if err != nil { + l.logger.Error(err.Error()) + } + } + for _, vpnPort := range l.vpnInputPorts { err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf) if err != nil { @@ -57,3 +84,50 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { l.logger.Error(err.Error()) } } + +var errVPNTypeUnknown = errors.New("unknown VPN type") + +func updateToMaxMTU(ctx context.Context, vpnInterface string, + serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger, +) error { + link, err := netlinker.LinkByName(vpnInterface) + if err != nil { + return fmt.Errorf("getting VPN interface by name: %w", err) + } + + // Note: no point testing for an MTU of 1500, it will never work due to the VPN + // protocol overhead, so start lower than 1500 according to the protocol used. + const physicalLinkMTU = 1500 + vpnLinkMTU := physicalLinkMTU + switch vpnType { + case "wireguard": + vpnLinkMTU -= 60 // Wireguard overhead + case "openvpn": + vpnLinkMTU -= 41 // OpenVPN overhead + default: + return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType) + } + + // Setting the VPN link MTU to 1500 might interrupt the connection until + // the new MTU is set again, but this is necessary to find the highest valid MTU. + logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU) + + err = netlinker.LinkSetMTU(link, vpnLinkMTU) + if err != nil { + return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) + } + + const pingTimeout = time.Second + vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger) + if err != nil { + return fmt.Errorf("path MTU discovering: %w", err) + } + + err = netlinker.LinkSetMTU(link, vpnLinkMTU) + if err != nil { + return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err) + } + + logger.Infof("VPN interface %s MTU set to maximum valid MTU %d", vpnInterface, vpnLinkMTU) + return nil +} diff --git a/internal/vpn/wireguard.go b/internal/vpn/wireguard.go index 7f5c4246..60fc9afd 100644 --- a/internal/vpn/wireguard.go +++ b/internal/vpn/wireguard.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/wireguard" @@ -16,11 +17,11 @@ import ( func setupWireguard(ctx context.Context, netlinker NetLinker, fw Firewall, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( - wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error, + wireguarder *wireguard.Wireguard, connection models.Connection, err error, ) { - connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) + connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) if err != nil { - return nil, "", false, fmt.Errorf("finding a VPN server: %w", err) + return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err) } wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported) @@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker, wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) if err != nil { - return nil, "", false, fmt.Errorf("creating Wireguard: %w", err) + return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err) } err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) if err != nil { - return nil, "", false, fmt.Errorf("setting firewall: %w", err) + return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err) } - return wireguarder, connection.ServerName, connection.PortForward, nil + return wireguarder, connection, nil }