Compare commits
18 Commits
ivp6-level
...
pmtu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90b9e81129 | ||
|
|
2391c890b4 | ||
|
|
51fd46b58e | ||
|
|
906e7b5ee1 | ||
|
|
5428580b8f | ||
|
|
6c25ee53f1 | ||
|
|
b9051b02bf | ||
|
|
f0f3193c1c | ||
|
|
c0ebd180cb | ||
|
|
b6e873cf25 | ||
|
|
ccc2f306b9 | ||
|
|
5b1dc295fe | ||
|
|
00bc8bbbbb | ||
|
|
8bef380d8c | ||
|
|
9ad1907574 | ||
|
|
d83999d954 | ||
|
|
162d244865 | ||
|
|
e21d798f57 |
@@ -159,8 +159,6 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
FIREWALL_INPUT_PORTS= \
|
||||
FIREWALL_OUTBOUND_SUBNETS= \
|
||||
FIREWALL_DEBUG=off \
|
||||
# IPv6
|
||||
IPV6_CHECK_ADDRESS=[2606:4700::6810:84e5]:443 \
|
||||
# Logging
|
||||
LOG_LEVEL=info \
|
||||
# Health
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
@@ -243,13 +242,10 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return err
|
||||
}
|
||||
|
||||
ipv6SupportLevel, err := netLinker.FindIPv6SupportLevel(ctx,
|
||||
allSettings.IPv6.CheckAddress, firewallConf)
|
||||
ipv6Supported, err := netLinker.IsIPv6Supported()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking for IPv6 support: %w", err)
|
||||
}
|
||||
ipv6Supported := ipv6SupportLevel == netlink.IPv6Supported ||
|
||||
ipv6SupportLevel == netlink.IPv6Internet
|
||||
|
||||
err = allSettings.Validate(storage, ipv6Supported, logger)
|
||||
if err != nil {
|
||||
@@ -434,7 +430,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
httpClient, unzipper, parallelResolver, publicIPLooper.Fetcher(), openvpnFileExtractor)
|
||||
|
||||
vpnLogger := logger.New(log.SetComponent("vpn"))
|
||||
vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6SupportLevel, allSettings.Firewall.VPNInputPorts,
|
||||
vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6Supported, allSettings.Firewall.VPNInputPorts,
|
||||
providers, storage, allSettings.Health, healthChecker, healthcheckServer, ovpnConf, netLinker, firewallConf,
|
||||
routingConf, portForwardLooper, cmder, publicIPLooper, dnsLooper, vpnLogger, httpClient,
|
||||
buildInfo, *allSettings.Version.Enabled)
|
||||
@@ -478,7 +474,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
logger.New(log.SetComponent("http server")),
|
||||
allSettings.ControlServer.AuthFilePath,
|
||||
buildInfo, vpnLooper, portForwardLooper, dnsLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6SupportLevel.IsSupported())
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up control server: %w", err)
|
||||
}
|
||||
@@ -554,9 +550,7 @@ type netLinker interface {
|
||||
Ruler
|
||||
Linker
|
||||
IsWireguardSupported() (ok bool, err error)
|
||||
FindIPv6SupportLevel(ctx context.Context,
|
||||
checkAddress netip.AddrPort, firewall netlink.Firewall,
|
||||
) (level netlink.IPv6SupportLevel, err error)
|
||||
IsIPv6Supported() (ok bool, err error)
|
||||
PatchLoggerLevel(level log.Level)
|
||||
}
|
||||
|
||||
@@ -587,6 +581,7 @@ type Linker interface {
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkSetMTU(link netlink.Link, mtu int) error
|
||||
}
|
||||
|
||||
type clier interface {
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type noopFirewall struct{}
|
||||
|
||||
func (f *noopFirewall) AcceptOutput(_ context.Context, _, _ string, _ netip.Addr,
|
||||
_ uint16, _ bool,
|
||||
) (err error) {
|
||||
return nil
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/openvpn/extract"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
@@ -41,9 +40,7 @@ type IPFetcher interface {
|
||||
}
|
||||
|
||||
type IPv6Checker interface {
|
||||
FindIPv6SupportLevel(ctx context.Context,
|
||||
checkAddress netip.AddrPort, firewall netlink.Firewall,
|
||||
) (level netlink.IPv6SupportLevel, err error)
|
||||
IsIPv6Supported() (supported bool, err error)
|
||||
}
|
||||
|
||||
func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
|
||||
@@ -61,14 +58,12 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
|
||||
}
|
||||
allSettings.SetDefaults()
|
||||
|
||||
ipv6SupportLevel, err := ipv6Checker.FindIPv6SupportLevel(context.Background(),
|
||||
allSettings.IPv6.CheckAddress, &noopFirewall{})
|
||||
ipv6Supported, err := ipv6Checker.IsIPv6Supported()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking for IPv6 support: %w", err)
|
||||
}
|
||||
|
||||
err = allSettings.Validate(storage, ipv6SupportLevel.IsSupported(), logger)
|
||||
if err != nil {
|
||||
if err = allSettings.Validate(storage, ipv6Supported, logger); err != nil {
|
||||
return fmt.Errorf("validating settings: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,13 +79,13 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
|
||||
unzipper, parallelResolver, ipFetcher, openvpnFileExtractor)
|
||||
providerConf := providers.Get(allSettings.VPN.Provider.Name)
|
||||
connection, err := providerConf.GetConnection(
|
||||
allSettings.VPN.Provider.ServerSelection, ipv6SupportLevel == netlink.IPv6Internet)
|
||||
allSettings.VPN.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lines := providerConf.OpenVPNConfig(connection,
|
||||
allSettings.VPN.OpenVPN, ipv6SupportLevel.IsSupported())
|
||||
allSettings.VPN.OpenVPN, ipv6Supported)
|
||||
|
||||
fmt.Println(strings.Join(lines, "\n"))
|
||||
return nil
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/reader"
|
||||
"github.com/qdm12/gotree"
|
||||
)
|
||||
|
||||
// IPv6 contains settings regarding IPv6 configuration.
|
||||
type IPv6 struct {
|
||||
// CheckAddress is the TCP ip:port address to dial to check
|
||||
// IPv6 is supported, in case a default IPv6 route is found.
|
||||
// It defaults to cloudflare.com address [2606:4700::6810:84e5]:443
|
||||
CheckAddress netip.AddrPort
|
||||
}
|
||||
|
||||
func (i IPv6) validate() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *IPv6) copy() (copied IPv6) {
|
||||
return IPv6{
|
||||
CheckAddress: i.CheckAddress,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *IPv6) overrideWith(other IPv6) {
|
||||
i.CheckAddress = gosettings.OverrideWithValidator(i.CheckAddress, other.CheckAddress)
|
||||
}
|
||||
|
||||
func (i *IPv6) setDefaults() {
|
||||
defaultCheckAddress := netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
|
||||
i.CheckAddress = gosettings.DefaultComparable(i.CheckAddress, defaultCheckAddress)
|
||||
}
|
||||
|
||||
func (i IPv6) String() string {
|
||||
return i.toLinesNode().String()
|
||||
}
|
||||
|
||||
func (i IPv6) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("IPv6 settings:")
|
||||
node.Appendf("Check address: %s", i.CheckAddress)
|
||||
return node
|
||||
}
|
||||
|
||||
func (i *IPv6) read(r *reader.Reader) (err error) {
|
||||
i.CheckAddress, err = r.NetipAddrPort("IPV6_CHECK_ADDRESS")
|
||||
return err
|
||||
}
|
||||
@@ -27,7 +27,6 @@ type Settings struct {
|
||||
Updater Updater
|
||||
Version Version
|
||||
VPN VPN
|
||||
IPv6 IPv6
|
||||
Pprof pprof.Settings
|
||||
}
|
||||
|
||||
@@ -54,7 +53,6 @@ func (s *Settings) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Support
|
||||
"system": s.System.validate,
|
||||
"updater": s.Updater.Validate,
|
||||
"version": s.Version.validate,
|
||||
"ipv6": s.IPv6.validate,
|
||||
// Pprof validation done in pprof constructor
|
||||
"VPN": func() error {
|
||||
return s.VPN.Validate(filterChoicesGetter, ipv6Supported, warner)
|
||||
@@ -87,7 +85,6 @@ func (s *Settings) copy() (copied Settings) {
|
||||
Version: s.Version.copy(),
|
||||
VPN: s.VPN.Copy(),
|
||||
Pprof: s.Pprof.Copy(),
|
||||
IPv6: s.IPv6.copy(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +106,6 @@ func (s *Settings) OverrideWith(other Settings,
|
||||
patchedSettings.Version.overrideWith(other.Version)
|
||||
patchedSettings.VPN.OverrideWith(other.VPN)
|
||||
patchedSettings.Pprof.OverrideWith(other.Pprof)
|
||||
patchedSettings.IPv6.overrideWith(other.IPv6)
|
||||
err = patchedSettings.Validate(filterChoicesGetter, ipv6Supported, warner)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -125,7 +121,6 @@ func (s *Settings) SetDefaults() {
|
||||
s.Health.SetDefaults()
|
||||
s.HTTPProxy.setDefaults()
|
||||
s.Log.setDefaults()
|
||||
s.IPv6.setDefaults()
|
||||
s.PublicIP.setDefaults()
|
||||
s.Shadowsocks.setDefaults()
|
||||
s.Storage.setDefaults()
|
||||
@@ -147,7 +142,6 @@ func (s Settings) toLinesNode() (node *gotree.Node) {
|
||||
node.AppendNode(s.DNS.toLinesNode())
|
||||
node.AppendNode(s.Firewall.toLinesNode())
|
||||
node.AppendNode(s.Log.toLinesNode())
|
||||
node.AppendNode(s.IPv6.toLinesNode())
|
||||
node.AppendNode(s.Health.toLinesNode())
|
||||
node.AppendNode(s.Shadowsocks.toLinesNode())
|
||||
node.AppendNode(s.HTTPProxy.toLinesNode())
|
||||
@@ -214,7 +208,6 @@ func (s *Settings) Read(r *reader.Reader, warner Warner) (err error) {
|
||||
"updater": s.Updater.read,
|
||||
"version": s.Version.read,
|
||||
"VPN": s.VPN.read,
|
||||
"IPv6": s.IPv6.read,
|
||||
"profiling": s.Pprof.Read,
|
||||
}
|
||||
|
||||
|
||||
@@ -55,8 +55,6 @@ func Test_Settings_String(t *testing.T) {
|
||||
| └── Enabled: yes
|
||||
├── Log settings:
|
||||
| └── Log level: INFO
|
||||
├── IPv6 settings:
|
||||
| └── Check address: [2606:4700::6810:84e5]:443
|
||||
├── Health settings:
|
||||
| ├── Server listening address: 127.0.0.1:9999
|
||||
| ├── Target address: cloudflare.com:443
|
||||
|
||||
@@ -45,6 +45,7 @@ type Wireguard struct {
|
||||
// It has been lowered to 1320 following quite a bit of
|
||||
// investigation in the issue:
|
||||
// https://github.com/qdm12/gluetun/issues/2533.
|
||||
// Note this should now be replaced with the PMTUD feature.
|
||||
MTU uint16 `json:"mtu"`
|
||||
// Implementation is the Wireguard implementation to use.
|
||||
// It can be "auto", "userspace" or "kernelspace".
|
||||
|
||||
@@ -162,24 +162,6 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
func (c *Config) AcceptOutput(ctx context.Context,
|
||||
protocol, intf string, ip netip.Addr, port uint16, remove bool,
|
||||
) error {
|
||||
interfaceFlag := "-o " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf("%s OUTPUT -d %s %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
appendOrDelete(remove), ip, interfaceFlag, protocol, protocol, port)
|
||||
if ip.Is4() {
|
||||
return c.runIptablesInstruction(ctx, instruction)
|
||||
} else if c.ip6Tables == "" {
|
||||
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
|
||||
}
|
||||
return c.runIP6tablesInstruction(ctx, instruction)
|
||||
}
|
||||
|
||||
// Thanks to @npawelek.
|
||||
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
|
||||
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,
|
||||
|
||||
@@ -1,19 +1,9 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
import "github.com/qdm12/log"
|
||||
|
||||
type DebugLogger interface {
|
||||
Debug(message string)
|
||||
Debugf(format string, args ...any)
|
||||
Patch(options ...log.Option)
|
||||
}
|
||||
|
||||
type Firewall interface {
|
||||
AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr,
|
||||
port uint16, remove bool) (err error)
|
||||
}
|
||||
|
||||
@@ -1,106 +1,37 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IPv6SupportLevel uint8
|
||||
|
||||
const (
|
||||
IPv6Unsupported = iota
|
||||
// IPv6Supported indicates the host supports IPv6 but has no access to the
|
||||
// Internet via IPv6. It is true if one IPv6 route is found and no default
|
||||
// IPv6 route is found.
|
||||
IPv6Supported
|
||||
// IPv6Internet indicates the host has access to the Internet via IPv6,
|
||||
// which is detected when a default IPv6 route is found.
|
||||
IPv6Internet
|
||||
)
|
||||
|
||||
func (i IPv6SupportLevel) IsSupported() bool {
|
||||
return i == IPv6Supported || i == IPv6Internet
|
||||
}
|
||||
|
||||
func (n *NetLink) FindIPv6SupportLevel(ctx context.Context,
|
||||
checkAddress netip.AddrPort, firewall Firewall,
|
||||
) (level IPv6SupportLevel, err error) {
|
||||
func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
|
||||
routes, err := n.RouteList(FamilyV6)
|
||||
if err != nil {
|
||||
return IPv6Unsupported, fmt.Errorf("listing IPv6 routes: %w", err)
|
||||
return false, fmt.Errorf("listing IPv6 routes: %w", err)
|
||||
}
|
||||
|
||||
// Check each route for IPv6 due to Podman bug listing IPv4 routes
|
||||
// as IPv6 routes at container start, see:
|
||||
// https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949
|
||||
level = IPv6Unsupported
|
||||
for _, route := range routes {
|
||||
link, err := n.LinkByIndex(route.LinkIndex)
|
||||
if err != nil {
|
||||
return IPv6Unsupported, fmt.Errorf("finding link corresponding to route: %w", err)
|
||||
return false, fmt.Errorf("finding link corresponding to route: %w", err)
|
||||
}
|
||||
|
||||
sourceIsIPv4 := route.Src.IsValid() && route.Src.Is4()
|
||||
destinationIsIPv4 := route.Dst.IsValid() && route.Dst.Addr().Is4()
|
||||
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
|
||||
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
|
||||
switch {
|
||||
case sourceIsIPv4 && destinationIsIPv4,
|
||||
case !sourceIsIPv6 && !destinationIsIPv6,
|
||||
destinationIsIPv6 && route.Dst.Addr().IsLoopback():
|
||||
case route.Dst.Addr().IsUnspecified(): // default ipv6 route
|
||||
n.debugLogger.Debugf("IPv6 default route found on link %s", link.Name)
|
||||
err = dialAddrThroughFirewall(ctx, link.Name, checkAddress, firewall)
|
||||
if err != nil {
|
||||
n.debugLogger.Debugf("IPv6 query failed on %s: %w", link.Name, err)
|
||||
level = IPv6Supported
|
||||
continue
|
||||
}
|
||||
n.debugLogger.Debugf("IPv6 internet is accessible through link %s", link.Name)
|
||||
return IPv6Internet, nil
|
||||
default: // non-default ipv6 route found
|
||||
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
|
||||
level = IPv6Supported
|
||||
continue
|
||||
}
|
||||
|
||||
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if level == IPv6Unsupported {
|
||||
n.debugLogger.Debugf("no IPv6 route found in %d routes", len(routes))
|
||||
}
|
||||
return level, nil
|
||||
}
|
||||
|
||||
func dialAddrThroughFirewall(ctx context.Context, intf string,
|
||||
checkAddress netip.AddrPort, firewall Firewall,
|
||||
) (err error) {
|
||||
const protocol = "tcp"
|
||||
remove := false
|
||||
err = firewall.AcceptOutput(ctx, protocol, intf,
|
||||
checkAddress.Addr(), checkAddress.Port(), remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting output traffic: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
remove = true
|
||||
firewallErr := firewall.AcceptOutput(ctx, protocol, intf,
|
||||
checkAddress.Addr(), checkAddress.Port(), remove)
|
||||
if err == nil && firewallErr != nil {
|
||||
err = fmt.Errorf("removing output traffic rule: %w", firewallErr)
|
||||
}
|
||||
}()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Second,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, protocol, checkAddress.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing: %w", err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
n.debugLogger.Debugf("IPv6 is not supported after searching %d routes",
|
||||
len(routes))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func isIPv6LocallySupported() bool {
|
||||
dialer := net.Dialer{Timeout: time.Millisecond}
|
||||
_, err := dialer.Dial("tcp6", "[::1]:9999")
|
||||
return !strings.HasSuffix(err.Error(), "connect: cannot assign requested address")
|
||||
}
|
||||
|
||||
// Susceptible to TOCTOU but it should be fine for the use case.
|
||||
func findAvailableTCPPort(t *testing.T) (port uint16) {
|
||||
t.Helper()
|
||||
|
||||
config := &net.ListenConfig{}
|
||||
listener, err := config.Listen(context.Background(), "tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
err = listener.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
return addrPort.Port()
|
||||
}
|
||||
|
||||
func Test_dialAddrThroughFirewall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errTest := errors.New("test error")
|
||||
|
||||
const ipv6InternetWorks = false
|
||||
|
||||
testCases := map[string]struct {
|
||||
getIPv6CheckAddr func(t *testing.T) netip.AddrPort
|
||||
firewallAddErr error
|
||||
firewallRemoveErr error
|
||||
errMessageRegex func() string
|
||||
}{
|
||||
"cloudflare.com": {
|
||||
getIPv6CheckAddr: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
|
||||
},
|
||||
errMessageRegex: func() string {
|
||||
if ipv6InternetWorks {
|
||||
return ""
|
||||
}
|
||||
return "dialing: dial tcp \\[2606:4700::6810:84e5\\]:443: " +
|
||||
"connect: (cannot assign requested address|network is unreachable)"
|
||||
},
|
||||
},
|
||||
"local_server": {
|
||||
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
|
||||
network := "tcp6"
|
||||
loopback := netip.MustParseAddr("::1")
|
||||
if !isIPv6LocallySupported() {
|
||||
network = "tcp4"
|
||||
loopback = netip.MustParseAddr("127.0.0.1")
|
||||
}
|
||||
|
||||
listener, err := net.ListenTCP(network, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := listener.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
||||
return netip.AddrPortFrom(loopback, addrPort.Port())
|
||||
},
|
||||
},
|
||||
"no_local_server": {
|
||||
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
|
||||
loopback := netip.MustParseAddr("::1")
|
||||
if !ipv6InternetWorks {
|
||||
loopback = netip.MustParseAddr("127.0.0.1")
|
||||
}
|
||||
|
||||
availablePort := findAvailableTCPPort(t)
|
||||
return netip.AddrPortFrom(loopback, availablePort)
|
||||
},
|
||||
errMessageRegex: func() string {
|
||||
return "dialing: dial tcp (\\[::1\\]|127\\.0\\.0\\.1):[1-9][0-9]{1,4}: " +
|
||||
"connect: connection refused"
|
||||
},
|
||||
},
|
||||
"firewall_add_error": {
|
||||
firewallAddErr: errTest,
|
||||
errMessageRegex: func() string {
|
||||
return "accepting output traffic: test error"
|
||||
},
|
||||
},
|
||||
"firewall_remove_error": {
|
||||
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
|
||||
network := "tcp4"
|
||||
loopback := netip.MustParseAddr("127.0.0.1")
|
||||
listener, err := net.ListenTCP(network, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := listener.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
||||
return netip.AddrPortFrom(loopback, addrPort.Port())
|
||||
},
|
||||
firewallRemoveErr: errTest,
|
||||
errMessageRegex: func() string {
|
||||
return "removing output traffic rule: test error"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var checkAddr netip.AddrPort
|
||||
if testCase.getIPv6CheckAddr != nil {
|
||||
checkAddr = testCase.getIPv6CheckAddr(t)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
const intf = "eth0"
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
call := firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
|
||||
checkAddr.Addr(), checkAddr.Port(), false).
|
||||
Return(testCase.firewallAddErr)
|
||||
if testCase.firewallAddErr == nil {
|
||||
firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
|
||||
checkAddr.Addr(), checkAddr.Port(), true).
|
||||
Return(testCase.firewallRemoveErr).After(call)
|
||||
}
|
||||
|
||||
err := dialAddrThroughFirewall(ctx, intf, checkAddr, firewall)
|
||||
var errMessageRegex string
|
||||
if testCase.errMessageRegex != nil {
|
||||
errMessageRegex = testCase.errMessageRegex()
|
||||
}
|
||||
if errMessageRegex == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Regexp(t, errMessageRegex, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
|
||||
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
|
||||
}
|
||||
|
||||
type netlinkLinkImpl struct {
|
||||
attrs *netlink.LinkAttrs
|
||||
linkType string
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
package netlink
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
|
||||
@@ -1,50 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/netlink (interfaces: Firewall)
|
||||
|
||||
// Package netlink is a generated GoMock package.
|
||||
package netlink
|
||||
|
||||
import (
|
||||
context "context"
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockFirewall is a mock of Firewall interface.
|
||||
type MockFirewall struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockFirewallMockRecorder
|
||||
}
|
||||
|
||||
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
|
||||
type MockFirewallMockRecorder struct {
|
||||
mock *MockFirewall
|
||||
}
|
||||
|
||||
// NewMockFirewall creates a new mock instance.
|
||||
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
|
||||
mock := &MockFirewall{ctrl: ctrl}
|
||||
mock.recorder = &MockFirewallMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptOutput mocks base method.
|
||||
func (m *MockFirewall) AcceptOutput(arg0 context.Context, arg1, arg2 string, arg3 netip.Addr, arg4 uint16, arg5 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptOutput", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AcceptOutput indicates an expected call of AcceptOutput.
|
||||
func (mr *MockFirewallMockRecorder) AcceptOutput(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutput", reflect.TypeOf((*MockFirewall)(nil).AcceptOutput), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
49
internal/pmtud/apple_ipv4.go
Normal file
49
internal/pmtud/apple_ipv4.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
var _ net.PacketConn = &ipv4Wrapper{}
|
||||
|
||||
// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement
|
||||
// the net.PacketConn interface. It's only used for Darwin or iOS.
|
||||
type ipv4Wrapper struct {
|
||||
ipv4Conn *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper {
|
||||
return &ipv4Wrapper{ipv4Conn: ipv4}
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, _, addr, err = i.ipv4Conn.ReadFrom(p)
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return i.ipv4Conn.WriteTo(p, nil, addr)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) Close() error {
|
||||
return i.ipv4Conn.Close()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) LocalAddr() net.Addr {
|
||||
return i.ipv4Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error {
|
||||
return i.ipv4Conn.SetWriteDeadline(t)
|
||||
}
|
||||
83
internal/pmtud/check.go
Normal file
83
internal/pmtud/check.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNextHopMTUTooLow = errors.New("ICMP Next Hop MTU is too low")
|
||||
ErrICMPNextHopMTUTooHigh = errors.New("ICMP Next Hop MTU is too high")
|
||||
)
|
||||
|
||||
func checkMTU(mtu, minMTU, physicalLinkMTU int) (err error) {
|
||||
switch {
|
||||
case mtu < minMTU:
|
||||
return fmt.Errorf("%w: %d", ErrICMPNextHopMTUTooLow, mtu)
|
||||
case mtu > physicalLinkMTU:
|
||||
return fmt.Errorf("%w: %d is larger than physical link MTU %d",
|
||||
ErrICMPNextHopMTUTooHigh, mtu, physicalLinkMTU)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkInvokingReplyIDMatch(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message,
|
||||
) (match bool, err error) {
|
||||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing invoking packet: %w", err)
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
return inboundBody.ID == outboundBody.ID, nil
|
||||
}
|
||||
|
||||
var ErrICMPIDMismatch = errors.New("ICMP id mismatch")
|
||||
|
||||
func checkEchoReply(icmpProtocol int, received []byte,
|
||||
outboundMessage *icmp.Message, truncatedBody bool,
|
||||
) (err error) {
|
||||
inboundMessage, err := icmp.ParseMessage(icmpProtocol, received)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing invoking packet: %w", err)
|
||||
}
|
||||
inboundBody, ok := inboundMessage.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, inboundMessage.Body)
|
||||
}
|
||||
outboundBody := outboundMessage.Body.(*icmp.Echo) //nolint:forcetypeassert
|
||||
if inboundBody.ID != outboundBody.ID {
|
||||
return fmt.Errorf("%w: sent id %d and received id %d",
|
||||
ErrICMPIDMismatch, outboundBody.ID, inboundBody.ID)
|
||||
}
|
||||
err = checkEchoBodies(outboundBody.Data, inboundBody.Data, truncatedBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking sent and received bodies: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch")
|
||||
|
||||
func checkEchoBodies(sent, received []byte, receivedTruncated bool) (err error) {
|
||||
if len(received) > len(sent) {
|
||||
return fmt.Errorf("%w: sent %d bytes and received %d bytes",
|
||||
ErrICMPEchoDataMismatch, len(sent), len(received))
|
||||
}
|
||||
if receivedTruncated {
|
||||
sent = sent[:len(received)]
|
||||
}
|
||||
if !bytes.Equal(received, sent) {
|
||||
return fmt.Errorf("%w: sent %x and received %x",
|
||||
ErrICMPEchoDataMismatch, sent, received)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
10
internal/pmtud/df.go
Normal file
10
internal/pmtud/df.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package pmtud
|
||||
|
||||
// setDontFragment for platforms other than Linux and Windows
|
||||
// is not implemented, so we just return assuming the don't
|
||||
// fragment flag is set on IP packets.
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
return nil
|
||||
}
|
||||
12
internal/pmtud/df_linux.go
Normal file
12
internal/pmtud/df_linux.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build linux
|
||||
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP,
|
||||
syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_PROBE)
|
||||
}
|
||||
13
internal/pmtud/df_windows.go
Normal file
13
internal/pmtud/df_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setDontFragment(fd uintptr) (err error) {
|
||||
// https://docs.microsoft.com/en-us/troubleshoot/windows/win32/header-library-requirement-socket-ipproto-ip
|
||||
// #define IP_DONTFRAGMENT 14 /* don't fragment IP datagrams */
|
||||
return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, 14, 1)
|
||||
}
|
||||
29
internal/pmtud/errors.go
Normal file
29
internal/pmtud/errors.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPNotPermitted = errors.New("ICMP not permitted")
|
||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrICMPCommunicationAdministrativelyProhibited = errors.New("communication administratively prohibited")
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), "sendto: operation not permitted"):
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
7
internal/pmtud/interfaces.go
Normal file
7
internal/pmtud/interfaces.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package pmtud
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(msg string, args ...any)
|
||||
Warnf(msg string, args ...any)
|
||||
}
|
||||
159
internal/pmtud/ipv4.go
Normal file
159
internal/pmtud/ipv4.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
// see https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
minIPv4MTU = 68
|
||||
icmpv4Protocol = 1
|
||||
)
|
||||
|
||||
func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
listenConfig.Control = func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var setDFErr error
|
||||
err := rawConn.Control(func(fd uintptr) {
|
||||
setDFErr = setDontFragment(fd) // runs when calling ListenPacket
|
||||
})
|
||||
if err == nil {
|
||||
err = setDFErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn))
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
|
||||
) (mtu int, err error) {
|
||||
if ip.Is6() {
|
||||
panic("IP address is not v4")
|
||||
}
|
||||
conn, err := listenICMPv4(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// First try to send a packet which is too big to get the maximum MTU
|
||||
// directly.
|
||||
outboundID, outboundMessage := buildMessageToSend("v4", physicalLinkMTU)
|
||||
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
for { // for loop in case we read an echo reply for another ICMP request
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
// Side note: echo reply should be at most the number of bytes
|
||||
// sent, and can be lower, more precisely 576-ipHeader bytes,
|
||||
// in case the next hop we are reaching replies with a destination
|
||||
// unreachable and wants to ensure the response makes it way back
|
||||
// by keeping a low packet size, see:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
|
||||
inboundMessage, err := icmp.ParseMessage(icmpv4Protocol, packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.DstUnreach:
|
||||
const fragmentationRequiredAndDFFlagSetCode = 4
|
||||
const communicationAdministrativelyProhibitedCode = 13
|
||||
switch inboundMessage.Code {
|
||||
case fragmentationRequiredAndDFFlagSetCode:
|
||||
case communicationAdministrativelyProhibitedCode:
|
||||
return 0, fmt.Errorf("%w: %w (code %d)",
|
||||
ErrICMPDestinationUnreachable,
|
||||
ErrICMPCommunicationAdministrativelyProhibited,
|
||||
inboundMessage.Code)
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: code %d",
|
||||
ErrICMPDestinationUnreachable, inboundMessage.Code)
|
||||
}
|
||||
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-4
|
||||
// Note: the go library does not handle this NextHopMTU section.
|
||||
nextHopMTU := packetBytes[6:8]
|
||||
mtu = int(binary.BigEndian.Uint16(nextHopMTU))
|
||||
err = checkMTU(mtu, minIPv4MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking next-hop-mtu found: %w", err)
|
||||
}
|
||||
|
||||
// The code below is really for sanity checks
|
||||
packetBytes = packetBytes[8:]
|
||||
header, err := ipv4.ParseHeader(packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing IPv4 header: %w", err)
|
||||
}
|
||||
packetBytes = packetBytes[header.Len:] // truncated original datagram
|
||||
|
||||
const truncated = true
|
||||
err = checkEchoReply(icmpv4Protocol, packetBytes, outboundMessage, truncated)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking echo reply: %w", err)
|
||||
}
|
||||
return mtu, nil
|
||||
case *icmp.Echo:
|
||||
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
122
internal/pmtud/ipv6.go
Normal file
122
internal/pmtud/ipv6.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const (
|
||||
minIPv6MTU = 1280
|
||||
icmpv6Protocol = 58
|
||||
)
|
||||
|
||||
func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) {
|
||||
var listenConfig net.ListenConfig
|
||||
const listenAddress = ""
|
||||
packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU int, pingTimeout time.Duration, logger Logger,
|
||||
) (mtu int, err error) {
|
||||
if ip.Is4() {
|
||||
panic("IP address is not v6")
|
||||
}
|
||||
conn, err := listenICMPv6(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// First try to send a packet which is too big to get the maximum MTU
|
||||
// directly.
|
||||
outboundID, outboundMessage := buildMessageToSend("v6", physicalLinkMTU)
|
||||
encodedMessage, err := outboundMessage.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, physicalLinkMTU)
|
||||
|
||||
for { // for loop if we encounter another ICMP packet with an unknown id.
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
packetBytes = packetBytes[ipv6.HeaderLen:]
|
||||
|
||||
inboundMessage, err := icmp.ParseMessage(icmpv6Protocol, packetBytes)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
switch typedBody := inboundMessage.Body.(type) {
|
||||
case *icmp.PacketTooBig:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.2
|
||||
mtu = typedBody.MTU
|
||||
err = checkMTU(mtu, minIPv6MTU, physicalLinkMTU)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking MTU: %w", err)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
const truncatedBody = true
|
||||
err = checkEchoReply(icmpv6Protocol, typedBody.Data, outboundMessage, truncatedBody)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message: %w", err)
|
||||
}
|
||||
return typedBody.MTU, nil
|
||||
case *icmp.DstUnreach:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1885#section-3.1
|
||||
idMatch, err := checkInvokingReplyIDMatch(icmpv6Protocol, packetBytes, outboundMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("checking invoking message id: %w", err)
|
||||
} else if idMatch {
|
||||
return 0, fmt.Errorf("%w", ErrICMPDestinationUnreachable)
|
||||
}
|
||||
logger.Debug("discarding received ICMP destination unreachable reply with an unknown id")
|
||||
continue
|
||||
case *icmp.Echo:
|
||||
inboundID := uint16(typedBody.ID) //nolint:gosec
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, typedBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
58
internal/pmtud/message.go
Normal file
58
internal/pmtud/message.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
func buildMessageToSend(ipVersion string, mtu int) (id uint16, message *icmp.Message) {
|
||||
var seed [32]byte
|
||||
_, _ = cryptorand.Read(seed[:])
|
||||
randomSource := rand.NewChaCha8(seed)
|
||||
|
||||
const uint16Bytes = 2
|
||||
idBytes := make([]byte, uint16Bytes)
|
||||
_, _ = randomSource.Read(idBytes)
|
||||
id = binary.BigEndian.Uint16(idBytes)
|
||||
|
||||
var ipHeaderLength int
|
||||
var icmpType icmp.Type
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
ipHeaderLength = ipv4.HeaderLen
|
||||
icmpType = ipv4.ICMPTypeEcho
|
||||
case "v6":
|
||||
ipHeaderLength = ipv6.HeaderLen
|
||||
icmpType = ipv6.ICMPTypeEchoRequest
|
||||
default:
|
||||
panic(fmt.Sprintf("IP version %q not supported", ipVersion))
|
||||
}
|
||||
const pingHeaderLength = 0 +
|
||||
1 + // type
|
||||
1 + // code
|
||||
2 + // checksum
|
||||
2 + // identifier
|
||||
2 // sequence number
|
||||
pingBodyDataSize := mtu - ipHeaderLength - pingHeaderLength
|
||||
messageBodyData := make([]byte, pingBodyDataSize)
|
||||
_, _ = randomSource.Read(messageBodyData)
|
||||
|
||||
// See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types
|
||||
message = &icmp.Message{
|
||||
Type: icmpType, // echo request
|
||||
Code: 0, // no code
|
||||
Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6)
|
||||
Body: &icmp.Echo{
|
||||
ID: int(id),
|
||||
Seq: 0, // only one packet
|
||||
Data: messageBodyData,
|
||||
},
|
||||
}
|
||||
return id, message
|
||||
}
|
||||
7
internal/pmtud/nooplogger.go
Normal file
7
internal/pmtud/nooplogger.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package pmtud
|
||||
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) Debug(_ string) {}
|
||||
func (noopLogger) Debugf(_ string, _ ...any) {}
|
||||
func (noopLogger) Warnf(_ string, _ ...any) {}
|
||||
271
internal/pmtud/pmtud.go
Normal file
271
internal/pmtud/pmtud.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
)
|
||||
|
||||
var ErrMTUNotFound = errors.New("path MTU discovery failed to find MTU")
|
||||
|
||||
// PathMTUDiscover discovers the maximum MTU for the path to the given ip address.
|
||||
// If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU.
|
||||
// If the pingTimeout is zero, it defaults to 1 second.
|
||||
// If the logger is nil, a no-op logger is used.
|
||||
// It returns [ErrMTUNotFound] if the MTU could not be determined.
|
||||
func PathMTUDiscover(ctx context.Context, ip netip.Addr,
|
||||
physicalLinkMTU int, pingTimeout time.Duration, logger Logger) (
|
||||
mtu int, err error,
|
||||
) {
|
||||
if physicalLinkMTU == 0 {
|
||||
const ethernetStandardMTU = 1500
|
||||
physicalLinkMTU = ethernetStandardMTU
|
||||
}
|
||||
if pingTimeout == 0 {
|
||||
pingTimeout = time.Second
|
||||
}
|
||||
if logger == nil {
|
||||
logger = &noopLogger{}
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
logger.Debug("finding IPv4 next hop MTU")
|
||||
mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed) || errors.Is(err, ErrICMPCommunicationAdministrativelyProhibited): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("requesting IPv6 ICMP packet-too-big reply")
|
||||
mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
return mtu, nil
|
||||
case errors.Is(err, net.ErrClosed): // blackhole
|
||||
default:
|
||||
return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back method: send echo requests with different packet
|
||||
// sizes and check which ones succeed to find the maximum MTU.
|
||||
logger.Debug("falling back to sending different sized echo packets")
|
||||
minMTU := minIPv4MTU
|
||||
if ip.Is6() {
|
||||
minMTU = minIPv6MTU
|
||||
}
|
||||
return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger)
|
||||
}
|
||||
|
||||
type pmtudTestUnit struct {
|
||||
mtu int
|
||||
echoID uint16
|
||||
sentBytes int
|
||||
ok bool
|
||||
}
|
||||
|
||||
func pmtudMultiSizes(ctx context.Context, ip netip.Addr,
|
||||
minMTU, maxPossibleMTU int, pingTimeout time.Duration,
|
||||
logger Logger,
|
||||
) (maxMTU int, err error) {
|
||||
var ipVersion string
|
||||
var conn net.PacketConn
|
||||
if ip.Is4() {
|
||||
ipVersion = "v4"
|
||||
conn, err = listenICMPv4(ctx)
|
||||
} else {
|
||||
ipVersion = "v6"
|
||||
conn, err = listenICMPv6(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "socket: operation not permitted") {
|
||||
err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("listening for ICMP packets: %w", err)
|
||||
}
|
||||
|
||||
mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU)
|
||||
if len(mtusToTest) == 1 { // only minMTU because minMTU == maxPossibleMTU
|
||||
return minMTU, nil
|
||||
}
|
||||
logger.Debugf("testing the following MTUs: %v", mtusToTest)
|
||||
|
||||
tests := make([]pmtudTestUnit, len(mtusToTest))
|
||||
for i := range mtusToTest {
|
||||
tests[i] = pmtudTestUnit{mtu: mtusToTest[i]}
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, pingTimeout)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-timedCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for i := range tests {
|
||||
id, message := buildMessageToSend(ipVersion, tests[i].mtu)
|
||||
tests[i].echoID = id
|
||||
|
||||
encodedMessage, err := message.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encoding ICMP message: %w", err)
|
||||
}
|
||||
tests[i].sentBytes = len(encodedMessage)
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||
err = fmt.Errorf("%w", ErrICMPNotPermitted)
|
||||
}
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = collectReplies(conn, ipVersion, tests, logger)
|
||||
switch {
|
||||
case err == nil: // max possible MTU is working
|
||||
return tests[len(tests)-1].mtu, nil
|
||||
case err != nil && errors.Is(err, net.ErrClosed):
|
||||
// we have timeouts (IPv4 testing or IPv6 PMTUD blackholes)
|
||||
// so find the highest MTU which worked.
|
||||
// Note we start from index len(tests) - 2 since the max MTU
|
||||
// cannot be working if we had a timeout.
|
||||
for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd
|
||||
if tests[i].ok {
|
||||
return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1,
|
||||
pingTimeout, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// All MTUs failed.
|
||||
return 0, fmt.Errorf("%w: ICMP might be blocked", ErrMTUNotFound)
|
||||
case err != nil:
|
||||
return 0, fmt.Errorf("collecting ICMP echo replies: %w", err)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// Create the MTU slice of length 11 such that:
|
||||
// - the first element is the minMTU
|
||||
// - the last element is the maxMTU
|
||||
// - elements in-between are separated as close to each other
|
||||
// The number 11 is chosen to find the final MTU in 3 searches,
|
||||
// with a total search space of 1728 MTUs which is enough;
|
||||
// to find it in 2 searches requires 37 parallel queries which
|
||||
// could be blocked by firewalls.
|
||||
func makeMTUsToTest(minMTU, maxMTU int) (mtus []int) {
|
||||
const mtusLength = 11 // find the final MTU in 3 searches
|
||||
diff := maxMTU - minMTU
|
||||
switch {
|
||||
case minMTU > maxMTU:
|
||||
panic("minMTU > maxMTU")
|
||||
case diff <= mtusLength:
|
||||
mtus = make([]int, 0, diff)
|
||||
for mtu := minMTU; mtu <= maxMTU; mtu++ {
|
||||
mtus = append(mtus, mtu)
|
||||
}
|
||||
default:
|
||||
step := float64(diff) / float64(mtusLength-1)
|
||||
mtus = make([]int, 0, mtusLength)
|
||||
for mtu := float64(minMTU); len(mtus) < mtusLength-1; mtu += step {
|
||||
mtus = append(mtus, int(math.Round(mtu)))
|
||||
}
|
||||
mtus = append(mtus, maxMTU) // last element is the maxMTU
|
||||
}
|
||||
|
||||
return mtus
|
||||
}
|
||||
|
||||
func collectReplies(conn net.PacketConn, ipVersion string,
|
||||
tests []pmtudTestUnit, logger Logger,
|
||||
) (err error) {
|
||||
echoIDToTestIndex := make(map[uint16]int, len(tests))
|
||||
for i, test := range tests {
|
||||
echoIDToTestIndex[test.echoID] = i
|
||||
}
|
||||
|
||||
// The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would
|
||||
// create huge buffers which we don't really want to support anyway.
|
||||
// The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with
|
||||
// a conventional maximum of 9000 bytes. However, some manufacturers support up
|
||||
// 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to
|
||||
// match eventual Jumbo frames. More information at:
|
||||
// https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media
|
||||
const maxPossibleMTU = 9196
|
||||
buffer := make([]byte, maxPossibleMTU)
|
||||
|
||||
idsFound := 0
|
||||
for idsFound < len(tests) {
|
||||
// Note we need to read the whole packet in one call to ReadFrom, so the buffer
|
||||
// must be large enough to read the entire reply packet. See:
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
|
||||
ipPacketLength := len(packetBytes)
|
||||
|
||||
var icmpProtocol int
|
||||
switch ipVersion {
|
||||
case "v4":
|
||||
icmpProtocol = icmpv4Protocol
|
||||
case "v6":
|
||||
icmpProtocol = icmpv6Protocol
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown IP version: %s", ipVersion))
|
||||
}
|
||||
|
||||
// Parse the ICMP message
|
||||
// Note: this parsing works for a truncated 556 bytes ICMP reply packet.
|
||||
message, err := icmp.ParseMessage(icmpProtocol, packetBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
echoBody, ok := message.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body)
|
||||
}
|
||||
|
||||
id := uint16(echoBody.ID) //nolint:gosec
|
||||
testIndex, testing := echoIDToTestIndex[id]
|
||||
if !testing { // not an id we expected so ignore it
|
||||
logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)",
|
||||
echoBody.ID, message.Type, message.Code, ipPacketLength)
|
||||
continue
|
||||
}
|
||||
idsFound++
|
||||
sentBytes := tests[testIndex].sentBytes
|
||||
|
||||
// echo reply should be at most the number of bytes sent,
|
||||
// and can be lower, more precisely 556 bytes, in case
|
||||
// the host we are reaching wants to stay out of trouble
|
||||
// and ensure its echo reply goes through without
|
||||
// fragmentation, see the following page:
|
||||
// https://datatracker.ietf.org/doc/html/rfc1122#page-59
|
||||
const conservativeReplyLength = 556
|
||||
truncated := ipPacketLength < sentBytes &&
|
||||
ipPacketLength == conservativeReplyLength
|
||||
// Check the packet size is the same if the reply is not truncated
|
||||
if !truncated && sentBytes != ipPacketLength {
|
||||
return fmt.Errorf("%w: sent %dB and received %dB",
|
||||
ErrICMPEchoDataMismatch, sentBytes, ipPacketLength)
|
||||
}
|
||||
// Truncated reply or matching reply size
|
||||
tests[testIndex].ok = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
22
internal/pmtud/pmtud_integration_test.go
Normal file
22
internal/pmtud/pmtud_integration_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build integration
|
||||
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_PathMTUDiscover(t *testing.T) {
|
||||
t.Parallel()
|
||||
const physicalLinkMTU = 1500
|
||||
const timeout = time.Second
|
||||
mtu, err := PathMTUDiscover(context.Background(), netip.MustParseAddr("1.1.1.1"),
|
||||
physicalLinkMTU, timeout, nil)
|
||||
require.NoError(t, err)
|
||||
t.Log("MTU found:", mtu)
|
||||
}
|
||||
55
internal/pmtud/pmtud_test.go
Normal file
55
internal/pmtud/pmtud_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_makeMTUsToTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
minMTU int
|
||||
maxMTU int
|
||||
mtus []int
|
||||
}{
|
||||
"0_0": {
|
||||
mtus: []int{0},
|
||||
},
|
||||
"0_1": {
|
||||
maxMTU: 1,
|
||||
mtus: []int{0, 1},
|
||||
},
|
||||
"0_8": {
|
||||
maxMTU: 8,
|
||||
mtus: []int{0, 1, 2, 3, 4, 5, 6, 7, 8},
|
||||
},
|
||||
"0_12": {
|
||||
maxMTU: 12,
|
||||
mtus: []int{0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12},
|
||||
},
|
||||
"0_80": {
|
||||
maxMTU: 80,
|
||||
mtus: []int{0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80},
|
||||
},
|
||||
"0_100": {
|
||||
maxMTU: 100,
|
||||
mtus: []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
|
||||
},
|
||||
"1280_1500": {
|
||||
minMTU: 1280,
|
||||
maxMTU: 1500,
|
||||
mtus: []int{1280, 1302, 1324, 1346, 1368, 1390, 1412, 1434, 1456, 1478, 1500},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mtus := makeMTUsToTest(testCase.minMTU, testCase.maxMTU)
|
||||
assert.Equal(t, testCase.mtus, mtus)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -81,6 +81,7 @@ type Linker interface {
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkSetMTU(link netlink.Link, mtu int) (err error)
|
||||
}
|
||||
|
||||
type DNSLoop interface {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/vpn/state"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
@@ -22,10 +21,10 @@ type Loop struct {
|
||||
healthChecker HealthChecker
|
||||
healthServer HealthServer
|
||||
// Fixed parameters
|
||||
buildInfo models.BuildInformation
|
||||
versionInfo bool
|
||||
ipv6SupportLevel netlink.IPv6SupportLevel
|
||||
vpnInputPorts []uint16 // TODO make changeable through stateful firewall
|
||||
buildInfo models.BuildInformation
|
||||
versionInfo bool
|
||||
ipv6Supported bool
|
||||
vpnInputPorts []uint16 // TODO make changeable through stateful firewall
|
||||
// Configurators
|
||||
openvpnConf OpenVPN
|
||||
netLinker NetLinker
|
||||
@@ -52,7 +51,7 @@ const (
|
||||
defaultBackoffTime = 15 * time.Second
|
||||
)
|
||||
|
||||
func NewLoop(vpnSettings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel, vpnInputPorts []uint16,
|
||||
func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint16,
|
||||
providers Providers, storage Storage, healthSettings settings.Health,
|
||||
healthChecker HealthChecker, healthServer HealthServer, openvpnConf OpenVPN,
|
||||
netLinker NetLinker, fw Firewall, routing Routing,
|
||||
@@ -70,32 +69,32 @@ func NewLoop(vpnSettings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel
|
||||
state := state.New(statusManager, vpnSettings)
|
||||
|
||||
return &Loop{
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
providers: providers,
|
||||
storage: storage,
|
||||
healthSettings: healthSettings,
|
||||
healthChecker: healthChecker,
|
||||
healthServer: healthServer,
|
||||
buildInfo: buildInfo,
|
||||
versionInfo: versionInfo,
|
||||
ipv6SupportLevel: ipv6SupportLevel,
|
||||
vpnInputPorts: vpnInputPorts,
|
||||
openvpnConf: openvpnConf,
|
||||
netLinker: netLinker,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
portForward: portForward,
|
||||
publicip: publicip,
|
||||
dnsLooper: dnsLooper,
|
||||
starter: starter,
|
||||
logger: logger,
|
||||
client: client,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
userTrigger: true,
|
||||
backoffTime: defaultBackoffTime,
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
providers: providers,
|
||||
storage: storage,
|
||||
healthSettings: healthSettings,
|
||||
healthChecker: healthChecker,
|
||||
healthServer: healthServer,
|
||||
buildInfo: buildInfo,
|
||||
versionInfo: versionInfo,
|
||||
ipv6Supported: ipv6Supported,
|
||||
vpnInputPorts: vpnInputPorts,
|
||||
openvpnConf: openvpnConf,
|
||||
netLinker: netLinker,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
portForward: portForward,
|
||||
publicip: publicip,
|
||||
dnsLooper: dnsLooper,
|
||||
starter: starter,
|
||||
logger: logger,
|
||||
client: client,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
userTrigger: true,
|
||||
backoffTime: defaultBackoffTime,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
@@ -15,16 +14,15 @@ import (
|
||||
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
||||
func setupOpenVPN(ctx context.Context, fw Firewall,
|
||||
openvpnConf OpenVPN, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel, starter CmdStarter,
|
||||
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
|
||||
logger openvpn.Logger) (runner *openvpn.Runner, connection models.Connection, err error,
|
||||
) {
|
||||
ipv6Internet := ipv6SupportLevel == netlink.IPv6Internet
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Internet)
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
|
||||
}
|
||||
|
||||
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6SupportLevel.IsSupported())
|
||||
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
|
||||
|
||||
if err := openvpnConf.WriteConfig(lines); err != nil {
|
||||
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
|
||||
|
||||
@@ -36,17 +36,18 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
if settings.Type == vpn.OpenVPN {
|
||||
vpnInterface = settings.OpenVPN.Interface
|
||||
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.ipv6SupportLevel, l.starter, subLogger)
|
||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
|
||||
} else { // Wireguard
|
||||
vpnInterface = settings.Wireguard.Interface
|
||||
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
providerConf, settings, l.ipv6SupportLevel, subLogger)
|
||||
providerConf, settings, l.ipv6Supported, subLogger)
|
||||
}
|
||||
if err != nil {
|
||||
l.crashed(ctx, err)
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
vpnType: settings.Type,
|
||||
serverIP: connection.IP,
|
||||
serverName: connection.ServerName,
|
||||
canPortForward: connection.PortForward,
|
||||
|
||||
@@ -2,16 +2,24 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud"
|
||||
"github.com/qdm12/gluetun/internal/version"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
type tunnelUpData struct {
|
||||
// Healthcheck
|
||||
serverIP netip.Addr
|
||||
// vpnType is used for path MTU discovery to find the protocol overhead.
|
||||
// It can be "wireguard" or "openvpn".
|
||||
vpnType string
|
||||
// Port forwarding
|
||||
vpnIntf string
|
||||
serverName string // used for PIA
|
||||
@@ -46,6 +54,13 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
||||
return
|
||||
}
|
||||
|
||||
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||
err = updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
|
||||
l.netLinker, l.routing, mtuLogger)
|
||||
if err != nil {
|
||||
mtuLogger.Error(err.Error())
|
||||
}
|
||||
|
||||
if *l.dnsLooper.GetSettings().ServerEnabled {
|
||||
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
|
||||
} else {
|
||||
@@ -112,3 +127,65 @@ func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
||||
_, _ = l.ApplyStatus(ctx, constants.Stopped)
|
||||
_, _ = l.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
||||
|
||||
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
|
||||
) error {
|
||||
logger.Info("finding maximum MTU, this can take up to 4 seconds")
|
||||
|
||||
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting VPN gateway IP address: %w", err)
|
||||
}
|
||||
|
||||
link, err := netlinker.LinkByName(vpnInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting VPN interface by name: %w", err)
|
||||
}
|
||||
|
||||
originalMTU := link.MTU
|
||||
|
||||
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
||||
// protocol overhead, so start lower than 1500 according to the protocol used.
|
||||
const physicalLinkMTU = 1500
|
||||
vpnLinkMTU := physicalLinkMTU
|
||||
switch vpnType {
|
||||
case "wireguard":
|
||||
vpnLinkMTU -= 60 // Wireguard overhead
|
||||
case "openvpn":
|
||||
vpnLinkMTU -= 41 // OpenVPN overhead
|
||||
default:
|
||||
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
||||
}
|
||||
|
||||
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
|
||||
|
||||
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
const pingTimeout = time.Second
|
||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
|
||||
vpnLinkMTU = int(originalMTU)
|
||||
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
||||
vpnInterface, originalMTU, err)
|
||||
default:
|
||||
return fmt.Errorf("path MTU discovering: %w", err)
|
||||
}
|
||||
|
||||
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
@@ -17,16 +16,15 @@ import (
|
||||
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
||||
func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
fw Firewall, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6SupportLevel netlink.IPv6SupportLevel, logger wireguard.Logger) (
|
||||
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
|
||||
) {
|
||||
ipv6Internet := ipv6SupportLevel == netlink.IPv6Internet
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Internet)
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
|
||||
}
|
||||
|
||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6SupportLevel.IsSupported())
|
||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
|
||||
|
||||
logger.Debug("Wireguard server public key: " + wireguardSettings.PublicKey)
|
||||
logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey))
|
||||
|
||||
Reference in New Issue
Block a user