chore(all): replace net.IP with netip.Addr

This commit is contained in:
Quentin McGaw
2023-05-20 19:58:18 +00:00
parent 00ee6ff9a7
commit 0a29337c3b
91 changed files with 525 additions and 590 deletions

View File

@@ -3,7 +3,7 @@ package routing
import (
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
@@ -14,8 +14,8 @@ var (
type DefaultRoute struct {
NetInterface string
Gateway net.IP
AssignedIP net.IP
Gateway netip.Addr
AssignedIP netip.Addr
Family int
}
@@ -35,7 +35,7 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
continue
}
defaultRoute := DefaultRoute{
Gateway: route.Gw,
Gateway: netIPToNetipAddress(route.Gw),
Family: route.Family,
}
linkIndex := route.LinkIndex

View File

@@ -62,7 +62,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP)
assignedIP := defaultRoute.AssignedIP
bits := 32
if assignedIP.Is6() {
bits = 128
@@ -82,7 +82,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
for _, defaultRoute := range defaultRoutes {
assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP)
assignedIP := defaultRoute.AssignedIP
bits := 32
if assignedIP.Is6() {
bits = 128

View File

@@ -4,11 +4,12 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func IPIsPrivate(ip net.IP) bool {
func IPIsPrivate(ip netip.Addr) bool {
return ip.IsPrivate() || ip.IsLoopback() ||
ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()
}
@@ -17,38 +18,26 @@ var (
errInterfaceIPNotFound = errors.New("IP address not found for interface")
)
func ipMatchesFamily(ip net.IP, family int) bool {
return (family == netlink.FAMILY_V6 && ip.To4() == nil) ||
(family == netlink.FAMILY_V4 && ip.To4() != nil)
func ipMatchesFamily(ip netip.Addr, family int) bool {
return (family == netlink.FAMILY_V6 && ip.Is6()) ||
(family == netlink.FAMILY_V4 && (ip.Is4() || ip.Is4In6()))
}
func ensureNoIPv6WrappedIPv4(candidateIP net.IP) (resultIP net.IP) {
const ipv4Size = 4
if candidateIP.To4() == nil || len(candidateIP) == ipv4Size { // ipv6 or ipv4
return candidateIP
}
// ipv6-wrapped ipv4
resultIP = make(net.IP, ipv4Size)
copy(resultIP, candidateIP[12:16])
return resultIP
}
func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err error) {
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return nil, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
}
addresses, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf("listing interface %s addresses: %w", interfaceName, err)
return ip, fmt.Errorf("listing interface %s addresses: %w", interfaceName, err)
}
for _, address := range addresses {
switch value := address.(type) {
case *net.IPAddr:
ip = value.IP
ip = netIPToNetipAddress(value.IP)
case *net.IPNet:
ip = value.IP
ip = netIPToNetipAddress(value.IP)
default:
continue
}
@@ -60,9 +49,8 @@ func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err e
// Ensure we don't return an IPv6-wrapped IPv4 address
// since netip.Address String method works differently than
// net.IP String method for this kind of addresses.
ip = ensureNoIPv6WrappedIPv4(ip)
return ip, nil
return ip.Unmap(), nil
}
return nil, fmt.Errorf("%w: interface %s in %d addresses",
return ip, fmt.Errorf("%w: interface %s in %d addresses",
errInterfaceIPNotFound, interfaceName, len(addresses))
}

View File

@@ -1,7 +1,7 @@
package routing
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@@ -87,8 +87,8 @@ func Test_IPIsPrivate(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
ip := net.ParseIP(testCase.ipString)
require.NotNil(t, ip)
ip, err := netip.ParseAddr(testCase.ipString)
require.NoError(t, err)
isPrivate := IPIsPrivate(ip)
@@ -96,35 +96,3 @@ func Test_IPIsPrivate(t *testing.T) {
})
}
}
func Test_ensureNoIPv6WrappedIPv4(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
candidateIP net.IP
resultIP net.IP
}{
"nil": {},
"ipv6": {
candidateIP: net.IPv6loopback,
resultIP: net.IPv6loopback,
},
"ipv4": {
candidateIP: net.IP{1, 2, 3, 4},
resultIP: net.IP{1, 2, 3, 4},
},
"ipv6_wrapped_ipv4": {
candidateIP: net.IPv4(1, 2, 3, 4),
resultIP: net.IP{1, 2, 3, 4},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
resultIP := ensureNoIPv6WrappedIPv4(testCase.candidateIP)
assert.Equal(t, testCase.resultIP, resultIP)
})
}
}

View File

@@ -3,7 +3,6 @@ package routing
import (
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
@@ -18,7 +17,7 @@ var (
type LocalNetwork struct {
IPNet netip.Prefix
InterfaceName string
IP net.IP
IP netip.Addr
}
func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {

View File

@@ -2,14 +2,13 @@ package routing
import (
"fmt"
"net"
"net/netip"
"strconv"
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int) error {
destinationStr := destination.String()
r.logger.Info("adding route for " + destinationStr)
@@ -25,7 +24,7 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway,
Gw: gateway.AsSlice(),
LinkIndex: link.Attrs().Index,
Table: table,
}
@@ -37,7 +36,7 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
return nil
}
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway net.IP,
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
iface string, table int) (err error) {
destinationStr := destination.String()
r.logger.Info("deleting route for " + destinationStr)
@@ -53,7 +52,7 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway net.IP,
route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway,
Gw: gateway.AsSlice(),
LinkIndex: link.Attrs().Index,
Table: table,
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
@@ -14,10 +15,10 @@ var (
ErrVPNLocalGatewayIPNotFound = errors.New("VPN local gateway IP address not found")
)
func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("listing routes: %w", err)
return ip, fmt.Errorf("listing routes: %w", err)
}
defaultLinkIndex := -1
@@ -28,36 +29,36 @@ func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
}
}
if defaultLinkIndex == -1 {
return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes))
return ip, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes))
}
for _, route := range routes {
if route.LinkIndex == defaultLinkIndex &&
route.Dst != nil &&
!IPIsPrivate(route.Dst.IP) &&
!IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) &&
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
return route.Dst.IP, nil
return netIPToNetipAddress(route.Dst.IP), nil
}
}
return nil, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
}
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("listing routes: %w", err)
return ip, fmt.Errorf("listing routes: %w", err)
}
for _, route := range routes {
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
if err != nil {
return nil, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
}
interfaceName := link.Attrs().Name
if interfaceName == vpnIntf &&
route.Dst != nil &&
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
return route.Gw, nil
return netIPToNetipAddress(route.Gw), nil
}
}
return nil, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
}