Compare commits

...

1 Commits

Author SHA1 Message Date
Quentin McGaw
c3eca4a17c wip 2024-11-08 17:25:12 +00:00
8 changed files with 62 additions and 5 deletions

View File

@@ -32,6 +32,11 @@ type Route struct {
Type int
}
func (r Route) String() string {
return fmt.Sprintf("{link %d, dst %s, src %s, gw %s, priority %d, family %d, table %d, type %d}",
r.LinkIndex, r.Dst, r.Src, r.Gw, r.Priority, r.Family, r.Table, r.Type)
}
type Rule struct {
Priority int
Family int

View File

@@ -67,6 +67,7 @@ type NetLinker interface {
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {

View File

@@ -38,7 +38,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.routing, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
}
if err != nil {

View File

@@ -13,7 +13,7 @@ import (
// setupWireguard sets Wireguard up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker NetLinker,
func setupWireguard(ctx context.Context, netlinker NetLinker, routing Routing,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
@@ -29,7 +29,7 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey))
logger.Debug("Wireguard pre-shared key: " + gosettings.ObfuscateKey(wireguardSettings.PreSharedKey))
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
wireguarder, err = wireguard.New(wireguardSettings, netlinker, routing, logger)
if err != nil {
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
}

View File

@@ -4,10 +4,11 @@ type Wireguard struct {
logger Logger
settings Settings
netlink NetLinker
routing Routing
}
func New(settings Settings, netlink NetLinker,
logger Logger,
routing Routing, logger Logger,
) (w *Wireguard, err error) {
settings.SetDefaults()
if err := settings.Check(); err != nil {
@@ -18,5 +19,6 @@ func New(settings Settings, netlink NetLinker,
logger: logger,
settings: settings,
netlink: netlink,
routing: routing,
}, nil
}

View File

@@ -0,0 +1,7 @@
package wireguard
import "net/netip"
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
}

View File

@@ -1,6 +1,8 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
import (
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
@@ -15,6 +17,7 @@ type NetLinker interface {
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {

View File

@@ -1,6 +1,7 @@
package wireguard
import (
"errors"
"fmt"
"net/netip"
"strings"
@@ -29,6 +30,10 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil
}
var (
ErrDefaultRouteNotFound = errors.New("default route not found")
)
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark uint32,
) (err error) {
@@ -45,5 +50,39 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
link.Name, dst, firewallMark, err)
}
vpnGatewayIP, err := w.routing.VPNLocalGatewayIP(link.Name)
if err != nil {
return fmt.Errorf("getting VPN gateway IP: %w", err)
}
routes, err := w.netlink.RouteList(netlink.FamilyV4)
if err != nil {
return fmt.Errorf("listing routes: %w", err)
}
var defaultRoute netlink.Route
var defaultRouteFound bool
for _, route = range routes {
if !route.Dst.IsValid() || route.Dst.Addr().IsUnspecified() {
defaultRoute = route
defaultRouteFound = true
break
}
}
if !defaultRouteFound {
return fmt.Errorf("%w: in %d routes", ErrDefaultRouteNotFound, len(routes))
}
// Equivalent replacement to:
// ip route replace default via <vpn-gateway> dev tun0
defaultRoute.Gw = vpnGatewayIP
defaultRoute.LinkIndex = link.Index
err = w.netlink.RouteReplace(defaultRoute)
if err != nil {
return fmt.Errorf("replacing default route: %w", err)
}
return err
}