Compare commits

..

18 Commits

Author SHA1 Message Date
Quentin McGaw
90b9e81129 Merge branch 'master' into pmtu 2025-11-07 21:55:58 +00:00
Quentin McGaw
2391c890b4 Run MTU discovery AFTER healthcheck is started 2025-10-17 00:39:44 +00:00
Quentin McGaw
51fd46b58e Merge branch 'master' into pmtu 2025-10-17 00:17:45 +00:00
Quentin McGaw
906e7b5ee1 Remove unneeded error context wrapping 2025-10-14 17:56:54 +00:00
Quentin McGaw
5428580b8f Handle ICMP not permitted errors 2025-10-14 17:56:04 +00:00
Quentin McGaw
6c25ee53f1 Fix unit test 2025-10-06 11:08:03 +00:00
Quentin McGaw
b9051b02bf Use the VPN local gateway IP address to run path MTU discovery 2025-10-06 10:03:15 +00:00
Quentin McGaw
f0f3193c1c Remove VPN_PMTUD option 2025-10-06 09:57:15 +00:00
Quentin McGaw
c0ebd180cb Revert to VPN original MTU (set by WIREGUARD_MTU for example) if ICMP fails 2025-10-06 09:57:15 +00:00
Quentin McGaw
b6e873cf25 Improve logging in case of ICMP blocked 2025-10-06 09:57:15 +00:00
Quentin McGaw
ccc2f306b9 Fallback on 1320 if ICMP is blocked 2025-10-06 09:57:15 +00:00
Quentin McGaw
5b1dc295fe Return an error if all MTUs failed to test 2025-10-06 09:57:15 +00:00
Quentin McGaw
00bc8bbbbb Handle administrative prohibition of ICMP 2025-10-06 09:57:15 +00:00
Quentin McGaw
8bef380d8c Fix unit test 2025-10-06 09:57:15 +00:00
Quentin McGaw
9ad1907574 Update log that PMTUD can take up to 4s 2025-10-06 09:57:15 +00:00
Quentin McGaw
d83999d954 Make binary search faster with 11 parallel queries 2025-10-06 09:57:15 +00:00
Quentin McGaw
162d244865 Use PMTUD to set the MTU to the VPN interface
- Add `VPN_PMTUD` option enabled by default
- One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
2025-10-06 09:57:15 +00:00
Quentin McGaw
e21d798f57 pmtud package 2025-10-06 09:57:15 +00:00
35 changed files with 1044 additions and 471 deletions

View File

@@ -159,8 +159,6 @@ ENV VPN_SERVICE_PROVIDER=pia \
FIREWALL_INPUT_PORTS= \
FIREWALL_OUTBOUND_SUBNETS= \
FIREWALL_DEBUG=off \
# IPv6
IPV6_CHECK_ADDRESS=[2606:4700::6810:84e5]:443 \
# Logging
LOG_LEVEL=info \
# Health

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io/fs"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
@@ -243,13 +242,10 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err
}
ipv6SupportLevel, err := netLinker.FindIPv6SupportLevel(ctx,
allSettings.IPv6.CheckAddress, firewallConf)
ipv6Supported, err := netLinker.IsIPv6Supported()
if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err)
}
ipv6Supported := ipv6SupportLevel == netlink.IPv6Supported ||
ipv6SupportLevel == netlink.IPv6Internet
err = allSettings.Validate(storage, ipv6Supported, logger)
if err != nil {
@@ -434,7 +430,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
httpClient, unzipper, parallelResolver, publicIPLooper.Fetcher(), openvpnFileExtractor)
vpnLogger := logger.New(log.SetComponent("vpn"))
vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6SupportLevel, allSettings.Firewall.VPNInputPorts,
vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6Supported, allSettings.Firewall.VPNInputPorts,
providers, storage, allSettings.Health, healthChecker, healthcheckServer, ovpnConf, netLinker, firewallConf,
routingConf, portForwardLooper, cmder, publicIPLooper, dnsLooper, vpnLogger, httpClient,
buildInfo, *allSettings.Version.Enabled)
@@ -478,7 +474,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
logger.New(log.SetComponent("http server")),
allSettings.ControlServer.AuthFilePath,
buildInfo, vpnLooper, portForwardLooper, dnsLooper, updaterLooper, publicIPLooper,
storage, ipv6SupportLevel.IsSupported())
storage, ipv6Supported)
if err != nil {
return fmt.Errorf("setting up control server: %w", err)
}
@@ -554,9 +550,7 @@ type netLinker interface {
Ruler
Linker
IsWireguardSupported() (ok bool, err error)
FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall netlink.Firewall,
) (level netlink.IPv6SupportLevel, err error)
IsIPv6Supported() (ok bool, err error)
PatchLoggerLevel(level log.Level)
}
@@ -587,6 +581,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) error
}
type clier interface {

View File

@@ -1,14 +0,0 @@
package cli
import (
"context"
"net/netip"
)
type noopFirewall struct{}
func (f *noopFirewall) AcceptOutput(_ context.Context, _, _ string, _ netip.Addr,
_ uint16, _ bool,
) (err error) {
return nil
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/openvpn/extract"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/storage"
@@ -41,9 +40,7 @@ type IPFetcher interface {
}
type IPv6Checker interface {
FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall netlink.Firewall,
) (level netlink.IPv6SupportLevel, err error)
IsIPv6Supported() (supported bool, err error)
}
func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
@@ -61,14 +58,12 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
}
allSettings.SetDefaults()
ipv6SupportLevel, err := ipv6Checker.FindIPv6SupportLevel(context.Background(),
allSettings.IPv6.CheckAddress, &noopFirewall{})
ipv6Supported, err := ipv6Checker.IsIPv6Supported()
if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err)
}
err = allSettings.Validate(storage, ipv6SupportLevel.IsSupported(), logger)
if err != nil {
if err = allSettings.Validate(storage, ipv6Supported, logger); err != nil {
return fmt.Errorf("validating settings: %w", err)
}
@@ -84,13 +79,13 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
unzipper, parallelResolver, ipFetcher, openvpnFileExtractor)
providerConf := providers.Get(allSettings.VPN.Provider.Name)
connection, err := providerConf.GetConnection(
allSettings.VPN.Provider.ServerSelection, ipv6SupportLevel == netlink.IPv6Internet)
allSettings.VPN.Provider.ServerSelection, ipv6Supported)
if err != nil {
return err
}
lines := providerConf.OpenVPNConfig(connection,
allSettings.VPN.OpenVPN, ipv6SupportLevel.IsSupported())
allSettings.VPN.OpenVPN, ipv6Supported)
fmt.Println(strings.Join(lines, "\n"))
return nil

View File

@@ -1,51 +0,0 @@
package settings
import (
"net/netip"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
// IPv6 contains settings regarding IPv6 configuration.
type IPv6 struct {
// CheckAddress is the TCP ip:port address to dial to check
// IPv6 is supported, in case a default IPv6 route is found.
// It defaults to cloudflare.com address [2606:4700::6810:84e5]:443
CheckAddress netip.AddrPort
}
func (i IPv6) validate() (err error) {
return nil
}
func (i *IPv6) copy() (copied IPv6) {
return IPv6{
CheckAddress: i.CheckAddress,
}
}
func (i *IPv6) overrideWith(other IPv6) {
i.CheckAddress = gosettings.OverrideWithValidator(i.CheckAddress, other.CheckAddress)
}
func (i *IPv6) setDefaults() {
defaultCheckAddress := netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
i.CheckAddress = gosettings.DefaultComparable(i.CheckAddress, defaultCheckAddress)
}
func (i IPv6) String() string {
return i.toLinesNode().String()
}
func (i IPv6) toLinesNode() (node *gotree.Node) {
node = gotree.New("IPv6 settings:")
node.Appendf("Check address: %s", i.CheckAddress)
return node
}
func (i *IPv6) read(r *reader.Reader) (err error) {
i.CheckAddress, err = r.NetipAddrPort("IPV6_CHECK_ADDRESS")
return err
}

View File

@@ -27,7 +27,6 @@ type Settings struct {
Updater Updater
Version Version
VPN VPN
IPv6 IPv6
Pprof pprof.Settings
}
@@ -54,7 +53,6 @@ func (s *Settings) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Support
"system": s.System.validate,
"updater": s.Updater.Validate,
"version": s.Version.validate,
"ipv6": s.IPv6.validate,
// Pprof validation done in pprof constructor
"VPN": func() error {
return s.VPN.Validate(filterChoicesGetter, ipv6Supported, warner)
@@ -87,7 +85,6 @@ func (s *Settings) copy() (copied Settings) {
Version: s.Version.copy(),
VPN: s.VPN.Copy(),
Pprof: s.Pprof.Copy(),
IPv6: s.IPv6.copy(),
}
}
@@ -109,7 +106,6 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.OverrideWith(other.VPN)
patchedSettings.Pprof.OverrideWith(other.Pprof)
patchedSettings.IPv6.overrideWith(other.IPv6)
err = patchedSettings.Validate(filterChoicesGetter, ipv6Supported, warner)
if err != nil {
return err
@@ -125,7 +121,6 @@ func (s *Settings) SetDefaults() {
s.Health.SetDefaults()
s.HTTPProxy.setDefaults()
s.Log.setDefaults()
s.IPv6.setDefaults()
s.PublicIP.setDefaults()
s.Shadowsocks.setDefaults()
s.Storage.setDefaults()
@@ -147,7 +142,6 @@ func (s Settings) toLinesNode() (node *gotree.Node) {
node.AppendNode(s.DNS.toLinesNode())
node.AppendNode(s.Firewall.toLinesNode())
node.AppendNode(s.Log.toLinesNode())
node.AppendNode(s.IPv6.toLinesNode())
node.AppendNode(s.Health.toLinesNode())
node.AppendNode(s.Shadowsocks.toLinesNode())
node.AppendNode(s.HTTPProxy.toLinesNode())
@@ -214,7 +208,6 @@ func (s *Settings) Read(r *reader.Reader, warner Warner) (err error) {
"updater": s.Updater.read,
"version": s.Version.read,
"VPN": s.VPN.read,
"IPv6": s.IPv6.read,
"profiling": s.Pprof.Read,
}

View File

@@ -55,8 +55,6 @@ func Test_Settings_String(t *testing.T) {
| └── Enabled: yes
├── Log settings:
| └── Log level: INFO
├── IPv6 settings:
| └── Check address: [2606:4700::6810:84e5]:443
├── Health settings:
| ├── Server listening address: 127.0.0.1:9999
| ├── Target address: cloudflare.com:443

View File

@@ -45,6 +45,7 @@ type Wireguard struct {
// It has been lowered to 1320 following quite a bit of
// investigation in the issue:
// https://github.com/qdm12/gluetun/issues/2533.
// Note this should now be replaced with the PMTUD feature.
MTU uint16 `json:"mtu"`
// Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace".

View File

@@ -162,24 +162,6 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}
func (c *Config) AcceptOutput(ctx context.Context,
protocol, intf string, ip netip.Addr, port uint16, remove bool,
) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT -d %s %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), ip, interfaceFlag, protocol, protocol, port)
if ip.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// Thanks to @npawelek.
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,

View File

@@ -1,19 +1,9 @@
package netlink
import (
"context"
"net/netip"
"github.com/qdm12/log"
)
import "github.com/qdm12/log"
type DebugLogger interface {
Debug(message string)
Debugf(format string, args ...any)
Patch(options ...log.Option)
}
type Firewall interface {
AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr,
port uint16, remove bool) (err error)
}

View File

@@ -1,106 +1,37 @@
package netlink
import (
"context"
"fmt"
"net"
"net/netip"
"time"
)
type IPv6SupportLevel uint8
const (
IPv6Unsupported = iota
// IPv6Supported indicates the host supports IPv6 but has no access to the
// Internet via IPv6. It is true if one IPv6 route is found and no default
// IPv6 route is found.
IPv6Supported
// IPv6Internet indicates the host has access to the Internet via IPv6,
// which is detected when a default IPv6 route is found.
IPv6Internet
)
func (i IPv6SupportLevel) IsSupported() bool {
return i == IPv6Supported || i == IPv6Internet
}
func (n *NetLink) FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall Firewall,
) (level IPv6SupportLevel, err error) {
func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
routes, err := n.RouteList(FamilyV6)
if err != nil {
return IPv6Unsupported, fmt.Errorf("listing IPv6 routes: %w", err)
return false, fmt.Errorf("listing IPv6 routes: %w", err)
}
// Check each route for IPv6 due to Podman bug listing IPv4 routes
// as IPv6 routes at container start, see:
// https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949
level = IPv6Unsupported
for _, route := range routes {
link, err := n.LinkByIndex(route.LinkIndex)
if err != nil {
return IPv6Unsupported, fmt.Errorf("finding link corresponding to route: %w", err)
return false, fmt.Errorf("finding link corresponding to route: %w", err)
}
sourceIsIPv4 := route.Src.IsValid() && route.Src.Is4()
destinationIsIPv4 := route.Dst.IsValid() && route.Dst.Addr().Is4()
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
switch {
case sourceIsIPv4 && destinationIsIPv4,
case !sourceIsIPv6 && !destinationIsIPv6,
destinationIsIPv6 && route.Dst.Addr().IsLoopback():
case route.Dst.Addr().IsUnspecified(): // default ipv6 route
n.debugLogger.Debugf("IPv6 default route found on link %s", link.Name)
err = dialAddrThroughFirewall(ctx, link.Name, checkAddress, firewall)
if err != nil {
n.debugLogger.Debugf("IPv6 query failed on %s: %w", link.Name, err)
level = IPv6Supported
continue
}
n.debugLogger.Debugf("IPv6 internet is accessible through link %s", link.Name)
return IPv6Internet, nil
default: // non-default ipv6 route found
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
level = IPv6Supported
continue
}
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
return true, nil
}
if level == IPv6Unsupported {
n.debugLogger.Debugf("no IPv6 route found in %d routes", len(routes))
}
return level, nil
}
func dialAddrThroughFirewall(ctx context.Context, intf string,
checkAddress netip.AddrPort, firewall Firewall,
) (err error) {
const protocol = "tcp"
remove := false
err = firewall.AcceptOutput(ctx, protocol, intf,
checkAddress.Addr(), checkAddress.Port(), remove)
if err != nil {
return fmt.Errorf("accepting output traffic: %w", err)
}
defer func() {
remove = true
firewallErr := firewall.AcceptOutput(ctx, protocol, intf,
checkAddress.Addr(), checkAddress.Port(), remove)
if err == nil && firewallErr != nil {
err = fmt.Errorf("removing output traffic rule: %w", firewallErr)
}
}()
dialer := &net.Dialer{
Timeout: time.Second,
}
conn, err := dialer.DialContext(ctx, protocol, checkAddress.String())
if err != nil {
return fmt.Errorf("dialing: %w", err)
}
err = conn.Close()
if err != nil {
return fmt.Errorf("closing connection: %w", err)
}
return nil
n.debugLogger.Debugf("IPv6 is not supported after searching %d routes",
len(routes))
return false, nil
}

View File

@@ -1,167 +0,0 @@
package netlink
import (
"context"
"errors"
"net"
"net/netip"
"strings"
"testing"
"time"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func isIPv6LocallySupported() bool {
dialer := net.Dialer{Timeout: time.Millisecond}
_, err := dialer.Dial("tcp6", "[::1]:9999")
return !strings.HasSuffix(err.Error(), "connect: cannot assign requested address")
}
// Susceptible to TOCTOU but it should be fine for the use case.
func findAvailableTCPPort(t *testing.T) (port uint16) {
t.Helper()
config := &net.ListenConfig{}
listener, err := config.Listen(context.Background(), "tcp", "localhost:0")
require.NoError(t, err)
addr := listener.Addr().String()
err = listener.Close()
require.NoError(t, err)
addrPort, err := netip.ParseAddrPort(addr)
require.NoError(t, err)
return addrPort.Port()
}
func Test_dialAddrThroughFirewall(t *testing.T) {
t.Parallel()
errTest := errors.New("test error")
const ipv6InternetWorks = false
testCases := map[string]struct {
getIPv6CheckAddr func(t *testing.T) netip.AddrPort
firewallAddErr error
firewallRemoveErr error
errMessageRegex func() string
}{
"cloudflare.com": {
getIPv6CheckAddr: func(_ *testing.T) netip.AddrPort {
return netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
},
errMessageRegex: func() string {
if ipv6InternetWorks {
return ""
}
return "dialing: dial tcp \\[2606:4700::6810:84e5\\]:443: " +
"connect: (cannot assign requested address|network is unreachable)"
},
},
"local_server": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
network := "tcp6"
loopback := netip.MustParseAddr("::1")
if !isIPv6LocallySupported() {
network = "tcp4"
loopback = netip.MustParseAddr("127.0.0.1")
}
listener, err := net.ListenTCP(network, nil)
require.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
})
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
},
"no_local_server": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
loopback := netip.MustParseAddr("::1")
if !ipv6InternetWorks {
loopback = netip.MustParseAddr("127.0.0.1")
}
availablePort := findAvailableTCPPort(t)
return netip.AddrPortFrom(loopback, availablePort)
},
errMessageRegex: func() string {
return "dialing: dial tcp (\\[::1\\]|127\\.0\\.0\\.1):[1-9][0-9]{1,4}: " +
"connect: connection refused"
},
},
"firewall_add_error": {
firewallAddErr: errTest,
errMessageRegex: func() string {
return "accepting output traffic: test error"
},
},
"firewall_remove_error": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
network := "tcp4"
loopback := netip.MustParseAddr("127.0.0.1")
listener, err := net.ListenTCP(network, nil)
require.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
})
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
firewallRemoveErr: errTest,
errMessageRegex: func() string {
return "removing output traffic rule: test error"
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var checkAddr netip.AddrPort
if testCase.getIPv6CheckAddr != nil {
checkAddr = testCase.getIPv6CheckAddr(t)
}
ctx := context.Background()
const intf = "eth0"
firewall := NewMockFirewall(ctrl)
call := firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
checkAddr.Addr(), checkAddr.Port(), false).
Return(testCase.firewallAddErr)
if testCase.firewallAddErr == nil {
firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
checkAddr.Addr(), checkAddr.Port(), true).
Return(testCase.firewallRemoveErr).After(call)
}
err := dialAddrThroughFirewall(ctx, intf, checkAddr, firewall)
var errMessageRegex string
if testCase.errMessageRegex != nil {
errMessageRegex = testCase.errMessageRegex()
}
if errMessageRegex == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Regexp(t, errMessageRegex, err.Error())
}
})
}
}

View File

@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string

View File

@@ -1,3 +0,0 @@
package netlink
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall

View File

@@ -1,50 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/netlink (interfaces: Firewall)
// Package netlink is a generated GoMock package.
package netlink
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutput mocks base method.
func (m *MockFirewall) AcceptOutput(arg0 context.Context, arg1, arg2 string, arg3 netip.Addr, arg4 uint16, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutput", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutput indicates an expected call of AcceptOutput.
func (mr *MockFirewallMockRecorder) AcceptOutput(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutput", reflect.TypeOf((*MockFirewall)(nil).AcceptOutput), arg0, arg1, arg2, arg3, arg4, arg5)
}

View File

@@ -0,0 +1,49 @@
package pmtud
import (
"net"
"time"
"golang.org/x/net/ipv4"
)
var _ net.PacketConn = &ipv4Wrapper{}
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
// the net.PacketConn interface. It's only used for Darwin or iOS.
type ipv4Wrapper struct {
ipv4Conn *ipv4.PacketConn
}
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
return &ipv4Wrapper{ipv4Conn: ipv4}
}
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
return n, addr, err
}
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return i.ipv4Conn.WriteTo(p, nil, addr)
}
func (i *ipv4Wrapper) Close() error {
return i.ipv4Conn.Close()
}
func (i *ipv4Wrapper) LocalAddr() net.Addr {
return i.ipv4Conn.LocalAddr()
}
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
return i.ipv4Conn.SetDeadline(t)
}
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
return i.ipv4Conn.SetReadDeadline(t)
}
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
return i.ipv4Conn.SetWriteDeadline(t)
}

83
internal/pmtud/check.go Normal file
View File

@@ -0,0 +1,83 @@
package pmtud
import (
"bytes"
"errors"
"fmt"
"golang.org/x/net/icmp"
)
var (
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
)
func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) {
switch {
case mtu < minMTU:
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
case mtu > physicalLinkMTU:
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
default:
return nil
}
}
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
outboundMessage *icmp.Message,
) (match bool, err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return false, fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
return inboundBody.ID == outboundBody.ID, nil
}
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
func checkEchoReply(icmpProtocol int, received []byte,
outboundMessage *icmp.Message, truncatedBody bool,
) (err error) {
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
if err != nil {
return fmt.Errorf("parsing invoking packet: %w", err)
}
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
}
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
if inboundBody.ID != outboundBody.ID {
return fmt.Errorf("%w: sent id %d and received id %d",
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
}
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
if err != nil {
return fmt.Errorf("checking sent and received bodies: %w", err)
}
return nil
}
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
if len(received) > len(sent) {
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
ErrICMPEchoDataMismatch, len(sent), len(received))
}
if receivedTruncated {
sent = sent[:len(received)]
}
if !bytes.Equal(received, sent) {
return fmt.Errorf("%w: sent %x and received %x",
ErrICMPEchoDataMismatch, sent, received)
}
return nil
}

10
internal/pmtud/df.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build !linux && !windows
package pmtud
// setDontFragment for platforms other than Linux and Windows
// is not implemented, so we just return assuming the don't
// fragment flag is set on IP packets.
func setDontFragment(fd uintptr) (err error) {
return nil
}

View File

@@ -0,0 +1,12 @@
//go:build linux
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package pmtud
import (
"syscall"
)
func setDontFragment(fd uintptr) (err error) {
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
}

29
internal/pmtud/errors.go Normal file
View File

@@ -0,0 +1,29 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrICMPNotPermitted = errors.New("ICMP not permitted")
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
err = fmt.Errorf("%w", ErrICMPNotPermitted)
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}

View File

@@ -0,0 +1,7 @@
package pmtud
type Logger interface {
Debug(msg string)
Debugf(msg string, args ...any)
Warnf(msg string, args ...any)
}

159
internal/pmtud/ipv4.go Normal file
View File

@@ -0,0 +1,159 @@
package pmtud
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"syscall"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
)
const (
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
minIPv4MTU = 68
icmpv4Protocol = 1
)
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
var setDFErr error
err := rawConn.Control(func(fd uintptr) {
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
})
if err == nil {
err = setDFErr
}
return err
}
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
}
return packetConn, nil
}
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
) (mtu int, err error) {
if ip.Is6() {
panic("IP address is not v4")
}
conn, err := listenICMPv4(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop in case we read an echo reply for another ICMP request
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
// Side note: echo reply should be at most the number of bytes
// sent, and can be lower, more precisely 576-ipHeader bytes,
// in case the next hop we are reaching replies with a destination
// unreachable and wants to ensure the response makes it way back
// by keeping a low packet size, see:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.DstUnreach:
const fragmentationRequiredAndDFFlagSetCode = 4
const communicationAdministrativelyProhibitedCode = 13
switch inboundMessage.Code {
case fragmentationRequiredAndDFFlagSetCode:
case communicationAdministrativelyProhibitedCode:
return 0, fmt.Errorf("%w: %w (code %d)",
ErrICMPDestinationUnreachable,
ErrICMPCommunicationAdministrativelyProhibited,
inboundMessage.Code)
default:
return 0, fmt.Errorf("%w: code %d",
ErrICMPDestinationUnreachable, inboundMessage.Code)
}
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
// Note: the go library does not handle this NextHopMTU section.
nextHopMTU := packetBytes[6:8]
mtu = int(binary.BigEndian.Uint16(nextHopMTU))
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
}
// The code below is really for sanity checks
packetBytes = packetBytes[8:]
header, err := ipv4.ParseHeader(packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
}
packetBytes = packetBytes[header.Len:] // truncated original datagram
const truncated = true
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
if err != nil {
return 0, fmt.Errorf("checking echo reply: %w", err)
}
return mtu, nil
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}

122
internal/pmtud/ipv6.go Normal file
View File

@@ -0,0 +1,122 @@
package pmtud
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
const (
minIPv6MTU = 1280
icmpv6Protocol = 58
)
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
var listenConfig net.ListenConfig
const listenAddress = ""
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return nil, err
}
return packetConn, nil
}
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
) (mtu int, err error) {
if ip.Is4() {
panic("IP address is not v6")
}
conn, err := listenICMPv6(ctx)
if err != nil {
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-ctx.Done()
conn.Close()
}()
// First try to send a packet which is too big to get the maximum MTU
// directly.
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
encodedMessage, err := outboundMessage.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
buffer := make([]byte, physicalLinkMTU)
for { // for loop if we encounter another ICMP packet with an unknown id.
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
packetBytes = packetBytes[ipv6.HeaderLen:]
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
if err != nil {
return 0, fmt.Errorf("parsing message: %w", err)
}
switch typedBody := inboundMessage.Body.(type) {
case *icmp.PacketTooBig:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
mtu = typedBody.MTU
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
if err != nil {
return 0, fmt.Errorf("checking MTU: %w", err)
}
// Sanity checks
const truncatedBody = true
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
if err != nil {
return 0, fmt.Errorf("checking invoking message: %w", err)
}
return typedBody.MTU, nil
case *icmp.DstUnreach:
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
if err != nil {
return 0, fmt.Errorf("checking invoking message id: %w", err)
} else if idMatch {
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
}
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
continue
case *icmp.Echo:
inboundID := uint16(typedBody.ID) //nolint:gosec
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
}
}
}

58
internal/pmtud/message.go Normal file
View File

@@ -0,0 +1,58 @@
package pmtud
import (
cryptorand "crypto/rand"
"encoding/binary"
"fmt"
"math/rand/v2"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func buildMessageToSend(ipVersion string, mtu int) (id uint16, message *icmp.Message) {
var seed [32]byte
_, _ = cryptorand.Read(seed[:])
randomSource := rand.NewChaCha8(seed)
const uint16Bytes = 2
idBytes := make([]byte, uint16Bytes)
_, _ = randomSource.Read(idBytes)
id = binary.BigEndian.Uint16(idBytes)
var ipHeaderLength int
var icmpType icmp.Type
switch ipVersion {
case "v4":
ipHeaderLength = ipv4.HeaderLen
icmpType = ipv4.ICMPTypeEcho
case "v6":
ipHeaderLength = ipv6.HeaderLen
icmpType = ipv6.ICMPTypeEchoRequest
default:
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
}
const pingHeaderLength = 0 +
1 + // type
1 + // code
2 + // checksum
2 + // identifier
2 // sequence number
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
messageBodyData := make([]byte, pingBodyDataSize)
_, _ = randomSource.Read(messageBodyData)
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
message = &icmp.Message{
Type: icmpType, // echo request
Code: 0, // no code
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
Body: &icmp.Echo{
ID: int(id),
Seq: 0, // only one packet
Data: messageBodyData,
},
}
return id, message
}

View File

@@ -0,0 +1,7 @@
package pmtud
type noopLogger struct{}
func (noopLogger) Debug(_ string) {}
func (noopLogger) Debugf(_ string, _ ...any) {}
func (noopLogger) Warnf(_ string, _ ...any) {}

271
internal/pmtud/pmtud.go Normal file
View File

@@ -0,0 +1,271 @@
package pmtud
import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"time"
"golang.org/x/net/icmp"
)
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
// If the pingTimeout is zero, it defaults to 1 second.
// If the logger is nil, a no-op logger is used.
// It returns [ErrMTUNotFound] if the MTU could not be determined.
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
physicalLinkMTU int, pingTimeout time.Duration, logger Logger) (
mtu int, err error,
) {
if physicalLinkMTU == 0 {
const ethernetStandardMTU = 1500
physicalLinkMTU = ethernetStandardMTU
}
if pingTimeout == 0 {
pingTimeout = time.Second
}
if logger == nil {
logger = &noopLogger{}
}
if ip.Is4() {
logger.Debug("finding IPv4 next hop MTU")
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
default:
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
}
} else {
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
switch {
case err == nil:
return mtu, nil
case errors.Is(err, net.ErrClosed): // blackhole
default:
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
}
}
// Fall back method: send echo requests with different packet
// sizes and check which ones succeed to find the maximum MTU.
logger.Debug("falling back to sending different sized echo packets")
minMTU := minIPv4MTU
if ip.Is6() {
minMTU = minIPv6MTU
}
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
}
type pmtudTestUnit struct {
mtu int
echoID uint16
sentBytes int
ok bool
}
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
minMTU, maxPossibleMTU int, pingTimeout time.Duration,
logger Logger,
) (maxMTU int, err error) {
var ipVersion string
var conn net.PacketConn
if ip.Is4() {
ipVersion = "v4"
conn, err = listenICMPv4(ctx)
} else {
ipVersion = "v6"
conn, err = listenICMPv6(ctx)
}
if err != nil {
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
}
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
return minMTU, nil
}
logger.Debugf("testing the following MTUs: %v", mtusToTest)
tests := make([]pmtudTestUnit, len(mtusToTest))
for i := range mtusToTest {
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
}
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
defer cancel()
go func() {
<-timedCtx.Done()
conn.Close()
}()
for i := range tests {
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
tests[i].echoID = id
encodedMessage, err := message.Marshal(nil)
if err != nil {
return 0, fmt.Errorf("encoding ICMP message: %w", err)
}
tests[i].sentBytes = len(encodedMessage)
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
err = fmt.Errorf("%w", ErrICMPNotPermitted)
}
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
}
err = collectReplies(conn, ipVersion, tests, logger)
switch {
case err == nil: // max possible MTU is working
return tests[len(tests)-1].mtu, nil
case err != nil && errors.Is(err, net.ErrClosed):
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
// so find the highest MTU which worked.
// Note we start from index len(tests) - 2 since the max MTU
// cannot be working if we had a timeout.
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
if tests[i].ok {
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
pingTimeout, logger)
}
}
// All MTUs failed.
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
case err != nil:
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
default:
panic("unreachable")
}
}
// Create the MTU slice of length 11 such that:
// - the first element is the minMTU
// - the last element is the maxMTU
// - elements in-between are separated as close to each other
// The number 11 is chosen to find the final MTU in 3 searches,
// with a total search space of 1728 MTUs which is enough;
// to find it in 2 searches requires 37 parallel queries which
// could be blocked by firewalls.
func makeMTUsToTest(minMTU, maxMTU int) (mtus []int) {
const mtusLength = 11 // find the final MTU in 3 searches
diff := maxMTU - minMTU
switch {
case minMTU > maxMTU:
panic("minMTU > maxMTU")
case diff <= mtusLength:
mtus = make([]int, 0, diff)
for mtu := minMTU; mtu <= maxMTU; mtu++ {
mtus = append(mtus, mtu)
}
default:
step := float64(diff) / float64(mtusLength-1)
mtus = make([]int, 0, mtusLength)
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
mtus = append(mtus, int(math.Round(mtu)))
}
mtus = append(mtus, maxMTU) // last element is the maxMTU
}
return mtus
}
func collectReplies(conn net.PacketConn, ipVersion string,
tests []pmtudTestUnit, logger Logger,
) (err error) {
echoIDToTestIndex := make(map[uint16]int, len(tests))
for i, test := range tests {
echoIDToTestIndex[test.echoID] = i
}
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
// create huge buffers which we don't really want to support anyway.
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
// a conventional maximum of 9000 bytes. However, some manufacturers support up
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
// match eventual Jumbo frames. More information at:
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
const maxPossibleMTU = 9196
buffer := make([]byte, maxPossibleMTU)
idsFound := 0
for idsFound < len(tests) {
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
// must be large enough to read the entire reply packet. See:
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
return fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
ipPacketLength := len(packetBytes)
var icmpProtocol int
switch ipVersion {
case "v4":
icmpProtocol = icmpv4Protocol
case "v6":
icmpProtocol = icmpv6Protocol
default:
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
}
// Parse the ICMP message
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
if err != nil {
return fmt.Errorf("parsing message: %w", err)
}
echoBody, ok := message.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
}
id := uint16(echoBody.ID) //nolint:gosec
testIndex, testing := echoIDToTestIndex[id]
if !testing { // not an id we expected so ignore it
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
echoBody.ID, message.Type, message.Code, ipPacketLength)
continue
}
idsFound++
sentBytes := tests[testIndex].sentBytes
// echo reply should be at most the number of bytes sent,
// and can be lower, more precisely 556 bytes, in case
// the host we are reaching wants to stay out of trouble
// and ensure its echo reply goes through without
// fragmentation, see the following page:
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
const conservativeReplyLength = 556
truncated := ipPacketLength < sentBytes &&
ipPacketLength == conservativeReplyLength
// Check the packet size is the same if the reply is not truncated
if !truncated && sentBytes != ipPacketLength {
return fmt.Errorf("%w: sent %dB and received %dB",
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
}
// Truncated reply or matching reply size
tests[testIndex].ok = true
}
return nil
}

View File

@@ -0,0 +1,22 @@
//go:build integration
package pmtud
import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_PathMTUDiscover(t *testing.T) {
t.Parallel()
const physicalLinkMTU = 1500
const timeout = time.Second
mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"),
physicalLinkMTU, timeout, nil)
require.NoError(t, err)
t.Log("MTU found:", mtu)
}

View File

@@ -0,0 +1,55 @@
package pmtud
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_makeMTUsToTest(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
minMTU int
maxMTU int
mtus []int
}{
"0_0": {
mtus: []int{0},
},
"0_1": {
maxMTU: 1,
mtus: []int{0, 1},
},
"0_8": {
maxMTU: 8,
mtus: []int{0, 1, 2, 3, 4, 5, 6, 7, 8},
},
"0_12": {
maxMTU: 12,
mtus: []int{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12},
},
"0_80": {
maxMTU: 80,
mtus: []int{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80},
},
"0_100": {
maxMTU: 100,
mtus: []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
},
"1280_1500": {
minMTU: 1280,
maxMTU: 1500,
mtus: []int{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
assert.Equal(t, testCase.mtus, mtus)
})
}
}

View File

@@ -81,6 +81,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) (err error)
}
type DNSLoop interface {

View File

@@ -8,7 +8,6 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/vpn/state"
"github.com/qdm12/log"
)
@@ -22,10 +21,10 @@ type Loop struct {
healthChecker HealthChecker
healthServer HealthServer
// Fixed parameters
buildInfo models.BuildInformation
versionInfo bool
ipv6SupportLevel netlink.IPv6SupportLevel
vpnInputPorts []uint16 // TODO make changeable through stateful firewall
buildInfo models.BuildInformation
versionInfo bool
ipv6Supported bool
vpnInputPorts []uint16 // TODO make changeable through stateful firewall
// Configurators
openvpnConf OpenVPN
netLinker NetLinker
@@ -52,7 +51,7 @@ const (
defaultBackoffTime = 15 * time.Second
)
func NewLoop(vpnSettings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel, vpnInputPorts []uint16,
func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint16,
providers Providers, storage Storage, healthSettings settings.Health,
healthChecker HealthChecker, healthServer HealthServer, openvpnConf OpenVPN,
netLinker NetLinker, fw Firewall, routing Routing,
@@ -70,32 +69,32 @@ func NewLoop(vpnSettings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel
state := state.New(statusManager, vpnSettings)
return &Loop{
statusManager: statusManager,
state: state,
providers: providers,
storage: storage,
healthSettings: healthSettings,
healthChecker: healthChecker,
healthServer: healthServer,
buildInfo: buildInfo,
versionInfo: versionInfo,
ipv6SupportLevel: ipv6SupportLevel,
vpnInputPorts: vpnInputPorts,
openvpnConf: openvpnConf,
netLinker: netLinker,
fw: fw,
routing: routing,
portForward: portForward,
publicip: publicip,
dnsLooper: dnsLooper,
starter: starter,
logger: logger,
client: client,
start: start,
running: running,
stop: stop,
stopped: stopped,
userTrigger: true,
backoffTime: defaultBackoffTime,
statusManager: statusManager,
state: state,
providers: providers,
storage: storage,
healthSettings: healthSettings,
healthChecker: healthChecker,
healthServer: healthServer,
buildInfo: buildInfo,
versionInfo: versionInfo,
ipv6Supported: ipv6Supported,
vpnInputPorts: vpnInputPorts,
openvpnConf: openvpnConf,
netLinker: netLinker,
fw: fw,
routing: routing,
portForward: portForward,
publicip: publicip,
dnsLooper: dnsLooper,
starter: starter,
logger: logger,
client: client,
start: start,
running: running,
stop: stop,
stopped: stopped,
userTrigger: true,
backoffTime: defaultBackoffTime,
}
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/provider"
)
@@ -15,16 +14,15 @@ import (
// It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupOpenVPN(ctx context.Context, fw Firewall,
openvpnConf OpenVPN, providerConf provider.Provider,
settings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel, starter CmdStarter,
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
logger openvpn.Logger) (runner *openvpn.Runner, connection models.Connection, err error,
) {
ipv6Internet := ipv6SupportLevel == netlink.IPv6Internet
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Internet)
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
}
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6SupportLevel.IsSupported())
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)

View File

@@ -36,17 +36,18 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
if settings.Type == vpn.OpenVPN {
vpnInterface = settings.OpenVPN.Interface
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6SupportLevel, l.starter, subLogger)
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6SupportLevel, subLogger)
providerConf, settings, l.ipv6Supported, subLogger)
}
if err != nil {
l.crashed(ctx, err)
continue
}
tunnelUpData := tunnelUpData{
vpnType: settings.Type,
serverIP: connection.IP,
serverName: connection.ServerName,
canPortForward: connection.PortForward,

View File

@@ -2,16 +2,24 @@ package vpn
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
"github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
)
type tunnelUpData struct {
// Healthcheck
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
// Port forwarding
vpnIntf string
serverName string // used for PIA
@@ -46,6 +54,13 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
return
}
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
err = updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
l.netLinker, l.routing, mtuLogger)
if err != nil {
mtuLogger.Error(err.Error())
}
if *l.dnsLooper.GetSettings().ServerEnabled {
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
} else {
@@ -112,3 +127,65 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
_, _ = l.ApplyStatus(ctx, constants.Stopped)
_, _ = l.ApplyStatus(ctx, constants.Running)
}
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
) error {
logger.Info("finding maximum MTU, this can take up to 4 seconds")
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN gateway IP address: %w", err)
}
link, err := netlinker.LinkByName(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err)
}
originalMTU := link.MTU
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
switch {
case err == nil:
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
vpnLinkMTU = int(originalMTU)
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
vpnInterface, originalMTU, err)
default:
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
return nil
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/wireguard"
@@ -17,16 +16,15 @@ import (
// It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel, logger wireguard.Logger) (
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
) {
ipv6Internet := ipv6SupportLevel == netlink.IPv6Internet
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Internet)
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
}
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6SupportLevel.IsSupported())
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
logger.Debug("Wireguard server public key: " + wireguardSettings.PublicKey)
logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey))