feat(netlink): detect IPv6 using query to address

- If a default IPv6 route is found, query the ip:port defined by `IPV6_CHECK_ADDRESS` to check for internet access
This commit is contained in:
Quentin McGaw
2024-12-12 06:48:43 +00:00
parent dae44051f6
commit 5ca13021e7
13 changed files with 384 additions and 7 deletions

View File

@@ -1,9 +1,19 @@
package netlink
import "github.com/qdm12/log"
import (
"context"
"net/netip"
"github.com/qdm12/log"
)
type DebugLogger interface {
Debug(message string)
Debugf(format string, args ...any)
Patch(options ...log.Option)
}
type Firewall interface {
AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr,
port uint16, remove bool) (err error)
}

View File

@@ -1,7 +1,11 @@
package netlink
import (
"context"
"fmt"
"net"
"net/netip"
"time"
)
type IPv6SupportLevel uint8
@@ -21,7 +25,9 @@ func (i IPv6SupportLevel) IsSupported() bool {
return i == IPv6Supported || i == IPv6Internet
}
func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
func (n *NetLink) FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall Firewall,
) (level IPv6SupportLevel, err error) {
routes, err := n.RouteList(FamilyV6)
if err != nil {
return IPv6Unsupported, fmt.Errorf("listing IPv6 routes: %w", err)
@@ -44,7 +50,14 @@ func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
case sourceIsIPv4 && destinationIsIPv4,
destinationIsIPv6 && route.Dst.Addr().IsLoopback():
case route.Dst.Addr().IsUnspecified(): // default ipv6 route
n.debugLogger.Debugf("IPv6 internet access is enabled on link %s", link.Name)
n.debugLogger.Debugf("IPv6 default route found on link %s", link.Name)
err = dialAddrThroughFirewall(ctx, link.Name, checkAddress, firewall)
if err != nil {
n.debugLogger.Debugf("IPv6 query failed on %s: %w", link.Name, err)
level = IPv6Supported
continue
}
n.debugLogger.Debugf("IPv6 internet is accessible through link %s", link.Name)
return IPv6Internet, nil
default: // non-default ipv6 route found
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
@@ -57,3 +70,37 @@ func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
}
return level, nil
}
func dialAddrThroughFirewall(ctx context.Context, intf string,
checkAddress netip.AddrPort, firewall Firewall,
) (err error) {
const protocol = "tcp"
remove := false
err = firewall.AcceptOutput(ctx, protocol, intf,
checkAddress.Addr(), checkAddress.Port(), remove)
if err != nil {
return fmt.Errorf("accepting output traffic: %w", err)
}
defer func() {
remove = true
firewallErr := firewall.AcceptOutput(ctx, protocol, intf,
checkAddress.Addr(), checkAddress.Port(), remove)
if err == nil && firewallErr != nil {
err = fmt.Errorf("removing output traffic rule: %w", firewallErr)
}
}()
dialer := &net.Dialer{
Timeout: time.Second,
}
conn, err := dialer.DialContext(ctx, protocol, checkAddress.String())
if err != nil {
return fmt.Errorf("dialing: %w", err)
}
err = conn.Close()
if err != nil {
return fmt.Errorf("closing connection: %w", err)
}
return nil
}

View File

@@ -0,0 +1,166 @@
package netlink
import (
"context"
"errors"
"net"
"net/netip"
"strings"
"testing"
"time"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func isIPv6LocallySupported() bool {
dialer := net.Dialer{Timeout: time.Millisecond}
_, err := dialer.Dial("tcp6", "[::1]:9999")
return !strings.HasSuffix(err.Error(), "connect: cannot assign requested address")
}
// Susceptible to TOCTOU but it should be fine for the use case.
func findAvailableTCPPort(t *testing.T) (port uint16) {
t.Helper()
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
addr := listener.Addr().String()
err = listener.Close()
require.NoError(t, err)
addrPort, err := netip.ParseAddrPort(addr)
require.NoError(t, err)
return addrPort.Port()
}
func Test_dialAddrThroughFirewall(t *testing.T) {
t.Parallel()
errTest := errors.New("test error")
const ipv6InternetWorks = false
testCases := map[string]struct {
getIPv6CheckAddr func(t *testing.T) netip.AddrPort
firewallAddErr error
firewallRemoveErr error
errMessageRegex func() string
}{
"cloudflare.com": {
getIPv6CheckAddr: func(_ *testing.T) netip.AddrPort {
return netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
},
errMessageRegex: func() string {
if ipv6InternetWorks {
return ""
}
return "dialing: dial tcp \\[2606:4700::6810:84e5\\]:443: " +
"connect: (cannot assign requested address|network is unreachable)"
},
},
"local_server": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
network := "tcp6"
loopback := netip.MustParseAddr("::1")
if !isIPv6LocallySupported() {
network = "tcp4"
loopback = netip.MustParseAddr("127.0.0.1")
}
listener, err := net.ListenTCP(network, nil)
require.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
})
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
},
"no_local_server": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
loopback := netip.MustParseAddr("::1")
if !ipv6InternetWorks {
loopback = netip.MustParseAddr("127.0.0.1")
}
availablePort := findAvailableTCPPort(t)
return netip.AddrPortFrom(loopback, availablePort)
},
errMessageRegex: func() string {
return "dialing: dial tcp (\\[::1\\]|127\\.0\\.0\\.1):[1-9][0-9]{1,4}: " +
"connect: connection refused"
},
},
"firewall_add_error": {
firewallAddErr: errTest,
errMessageRegex: func() string {
return "accepting output traffic: test error"
},
},
"firewall_remove_error": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
network := "tcp4"
loopback := netip.MustParseAddr("127.0.0.1")
listener, err := net.ListenTCP(network, nil)
require.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
})
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
firewallRemoveErr: errTest,
errMessageRegex: func() string {
return "removing output traffic rule: test error"
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var checkAddr netip.AddrPort
if testCase.getIPv6CheckAddr != nil {
checkAddr = testCase.getIPv6CheckAddr(t)
}
ctx := context.Background()
const intf = "eth0"
firewall := NewMockFirewall(ctrl)
call := firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
checkAddr.Addr(), checkAddr.Port(), false).
Return(testCase.firewallAddErr)
if testCase.firewallAddErr == nil {
firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
checkAddr.Addr(), checkAddr.Port(), true).
Return(testCase.firewallRemoveErr).After(call)
}
err := dialAddrThroughFirewall(ctx, intf, checkAddr, firewall)
var errMessageRegex string
if testCase.errMessageRegex != nil {
errMessageRegex = testCase.errMessageRegex()
}
if errMessageRegex == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Regexp(t, errMessageRegex, err.Error())
}
})
}
}

View File

@@ -0,0 +1,3 @@
package netlink
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall

View File

@@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/netlink (interfaces: Firewall)
// Package netlink is a generated GoMock package.
package netlink
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutput mocks base method.
func (m *MockFirewall) AcceptOutput(arg0 context.Context, arg1, arg2 string, arg3 netip.Addr, arg4 uint16, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutput", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutput indicates an expected call of AcceptOutput.
func (mr *MockFirewallMockRecorder) AcceptOutput(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutput", reflect.TypeOf((*MockFirewall)(nil).AcceptOutput), arg0, arg1, arg2, arg3, arg4, arg5)
}