chore(all): replace net.IP with netip.Addr
This commit is contained in:
@@ -3,7 +3,7 @@ package routing
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
@@ -14,8 +14,8 @@ var (
|
||||
|
||||
type DefaultRoute struct {
|
||||
NetInterface string
|
||||
Gateway net.IP
|
||||
AssignedIP net.IP
|
||||
Gateway netip.Addr
|
||||
AssignedIP netip.Addr
|
||||
Family int
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
||||
continue
|
||||
}
|
||||
defaultRoute := DefaultRoute{
|
||||
Gateway: route.Gw,
|
||||
Gateway: netIPToNetipAddress(route.Gw),
|
||||
Family: route.Family,
|
||||
}
|
||||
linkIndex := route.LinkIndex
|
||||
|
||||
@@ -62,7 +62,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
|
||||
|
||||
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
||||
for _, defaultRoute := range defaultRoutes {
|
||||
assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP)
|
||||
assignedIP := defaultRoute.AssignedIP
|
||||
bits := 32
|
||||
if assignedIP.Is6() {
|
||||
bits = 128
|
||||
@@ -82,7 +82,7 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
|
||||
|
||||
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
||||
for _, defaultRoute := range defaultRoutes {
|
||||
assignedIP := netIPToNetipAddress(defaultRoute.AssignedIP)
|
||||
assignedIP := defaultRoute.AssignedIP
|
||||
bits := 32
|
||||
if assignedIP.Is6() {
|
||||
bits = 128
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func IPIsPrivate(ip net.IP) bool {
|
||||
func IPIsPrivate(ip netip.Addr) bool {
|
||||
return ip.IsPrivate() || ip.IsLoopback() ||
|
||||
ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()
|
||||
}
|
||||
@@ -17,38 +18,26 @@ var (
|
||||
errInterfaceIPNotFound = errors.New("IP address not found for interface")
|
||||
)
|
||||
|
||||
func ipMatchesFamily(ip net.IP, family int) bool {
|
||||
return (family == netlink.FAMILY_V6 && ip.To4() == nil) ||
|
||||
(family == netlink.FAMILY_V4 && ip.To4() != nil)
|
||||
func ipMatchesFamily(ip netip.Addr, family int) bool {
|
||||
return (family == netlink.FAMILY_V6 && ip.Is6()) ||
|
||||
(family == netlink.FAMILY_V4 && (ip.Is4() || ip.Is4In6()))
|
||||
}
|
||||
|
||||
func ensureNoIPv6WrappedIPv4(candidateIP net.IP) (resultIP net.IP) {
|
||||
const ipv4Size = 4
|
||||
if candidateIP.To4() == nil || len(candidateIP) == ipv4Size { // ipv6 or ipv4
|
||||
return candidateIP
|
||||
}
|
||||
|
||||
// ipv6-wrapped ipv4
|
||||
resultIP = make(net.IP, ipv4Size)
|
||||
copy(resultIP, candidateIP[12:16])
|
||||
return resultIP
|
||||
}
|
||||
|
||||
func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err error) {
|
||||
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
|
||||
return ip, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
|
||||
}
|
||||
addresses, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing interface %s addresses: %w", interfaceName, err)
|
||||
return ip, fmt.Errorf("listing interface %s addresses: %w", interfaceName, err)
|
||||
}
|
||||
for _, address := range addresses {
|
||||
switch value := address.(type) {
|
||||
case *net.IPAddr:
|
||||
ip = value.IP
|
||||
ip = netIPToNetipAddress(value.IP)
|
||||
case *net.IPNet:
|
||||
ip = value.IP
|
||||
ip = netIPToNetipAddress(value.IP)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
@@ -60,9 +49,8 @@ func (r *Routing) assignedIP(interfaceName string, family int) (ip net.IP, err e
|
||||
// Ensure we don't return an IPv6-wrapped IPv4 address
|
||||
// since netip.Address String method works differently than
|
||||
// net.IP String method for this kind of addresses.
|
||||
ip = ensureNoIPv6WrappedIPv4(ip)
|
||||
return ip, nil
|
||||
return ip.Unmap(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w: interface %s in %d addresses",
|
||||
return ip, fmt.Errorf("%w: interface %s in %d addresses",
|
||||
errInterfaceIPNotFound, interfaceName, len(addresses))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -87,8 +87,8 @@ func Test_IPIsPrivate(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ip := net.ParseIP(testCase.ipString)
|
||||
require.NotNil(t, ip)
|
||||
ip, err := netip.ParseAddr(testCase.ipString)
|
||||
require.NoError(t, err)
|
||||
|
||||
isPrivate := IPIsPrivate(ip)
|
||||
|
||||
@@ -96,35 +96,3 @@ func Test_IPIsPrivate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ensureNoIPv6WrappedIPv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
candidateIP net.IP
|
||||
resultIP net.IP
|
||||
}{
|
||||
"nil": {},
|
||||
"ipv6": {
|
||||
candidateIP: net.IPv6loopback,
|
||||
resultIP: net.IPv6loopback,
|
||||
},
|
||||
"ipv4": {
|
||||
candidateIP: net.IP{1, 2, 3, 4},
|
||||
resultIP: net.IP{1, 2, 3, 4},
|
||||
},
|
||||
"ipv6_wrapped_ipv4": {
|
||||
candidateIP: net.IPv4(1, 2, 3, 4),
|
||||
resultIP: net.IP{1, 2, 3, 4},
|
||||
},
|
||||
}
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resultIP := ensureNoIPv6WrappedIPv4(testCase.candidateIP)
|
||||
assert.Equal(t, testCase.resultIP, resultIP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package routing
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
@@ -18,7 +17,7 @@ var (
|
||||
type LocalNetwork struct {
|
||||
IPNet netip.Prefix
|
||||
InterfaceName string
|
||||
IP net.IP
|
||||
IP netip.Addr
|
||||
}
|
||||
|
||||
func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
||||
|
||||
@@ -2,14 +2,13 @@ package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
|
||||
func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
|
||||
func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||
iface string, table int) error {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("adding route for " + destinationStr)
|
||||
@@ -25,7 +24,7 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
|
||||
|
||||
route := netlink.Route{
|
||||
Dst: NetipPrefixToIPNet(&destination),
|
||||
Gw: gateway,
|
||||
Gw: gateway.AsSlice(),
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: table,
|
||||
}
|
||||
@@ -37,7 +36,7 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway net.IP,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway net.IP,
|
||||
func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
|
||||
iface string, table int) (err error) {
|
||||
destinationStr := destination.String()
|
||||
r.logger.Info("deleting route for " + destinationStr)
|
||||
@@ -53,7 +52,7 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway net.IP,
|
||||
|
||||
route := netlink.Route{
|
||||
Dst: NetipPrefixToIPNet(&destination),
|
||||
Gw: gateway,
|
||||
Gw: gateway.AsSlice(),
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: table,
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/netlink"
|
||||
)
|
||||
@@ -14,10 +15,10 @@ var (
|
||||
ErrVPNLocalGatewayIPNotFound = errors.New("VPN local gateway IP address not found")
|
||||
)
|
||||
|
||||
func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
|
||||
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
|
||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing routes: %w", err)
|
||||
return ip, fmt.Errorf("listing routes: %w", err)
|
||||
}
|
||||
|
||||
defaultLinkIndex := -1
|
||||
@@ -28,36 +29,36 @@ func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
|
||||
}
|
||||
}
|
||||
if defaultLinkIndex == -1 {
|
||||
return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes))
|
||||
return ip, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes))
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.LinkIndex == defaultLinkIndex &&
|
||||
route.Dst != nil &&
|
||||
!IPIsPrivate(route.Dst.IP) &&
|
||||
!IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) &&
|
||||
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
|
||||
return route.Dst.IP, nil
|
||||
return netIPToNetipAddress(route.Dst.IP), nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
|
||||
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
|
||||
}
|
||||
|
||||
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
|
||||
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing routes: %w", err)
|
||||
return ip, fmt.Errorf("listing routes: %w", err)
|
||||
}
|
||||
for _, route := range routes {
|
||||
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
|
||||
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
|
||||
}
|
||||
interfaceName := link.Attrs().Name
|
||||
if interfaceName == vpnIntf &&
|
||||
route.Dst != nil &&
|
||||
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
|
||||
return route.Gw, nil
|
||||
return netIPToNetipAddress(route.Gw), nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
||||
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user