Compare commits

...

1 Commits

Author SHA1 Message Date
Quentin McGaw
c3eca4a17c wip 2024-11-08 17:25:12 +00:00
8 changed files with 62 additions and 5 deletions

View File

@@ -32,6 +32,11 @@ type Route struct {
Type int Type int
} }
func (r Route) String() string {
return fmt.Sprintf("{link %d, dst %s, src %s, gw %s, priority %d, family %d, table %d, type %d}",
r.LinkIndex, r.Dst, r.Src, r.Gw, r.Priority, r.Family, r.Table, r.Type)
}
type Rule struct { type Rule struct {
Priority int Priority int
Family int Family int

View File

@@ -67,6 +67,7 @@ type NetLinker interface {
type Router interface { type Router interface {
RouteList(family int) (routes []netlink.Route, err error) RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {

View File

@@ -38,7 +38,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard } else { // Wireguard
vpnInterface = settings.Wireguard.Interface vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw, vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.routing, l.fw,
providerConf, settings, l.ipv6Supported, subLogger) providerConf, settings, l.ipv6Supported, subLogger)
} }
if err != nil { if err != nil {

View File

@@ -13,7 +13,7 @@ import (
// setupWireguard sets Wireguard up using the configurators and settings given. // setupWireguard sets Wireguard up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails. // It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker NetLinker, func setupWireguard(ctx context.Context, netlinker NetLinker, routing Routing,
fw Firewall, providerConf provider.Provider, fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error, wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
@@ -29,7 +29,7 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey)) logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey))
logger.Debug("Wireguard pre-shared key: " + gosettings.ObfuscateKey(wireguardSettings.PreSharedKey)) logger.Debug("Wireguard pre-shared key: " + gosettings.ObfuscateKey(wireguardSettings.PreSharedKey))
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) wireguarder, err = wireguard.New(wireguardSettings, netlinker, routing, logger)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err) return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
} }

View File

@@ -4,10 +4,11 @@ type Wireguard struct {
logger Logger logger Logger
settings Settings settings Settings
netlink NetLinker netlink NetLinker
routing Routing
} }
func New(settings Settings, netlink NetLinker, func New(settings Settings, netlink NetLinker,
logger Logger, routing Routing, logger Logger,
) (w *Wireguard, err error) { ) (w *Wireguard, err error) {
settings.SetDefaults() settings.SetDefaults()
if err := settings.Check(); err != nil { if err := settings.Check(); err != nil {
@@ -18,5 +19,6 @@ func New(settings Settings, netlink NetLinker,
logger: logger, logger: logger,
settings: settings, settings: settings,
netlink: netlink, netlink: netlink,
routing: routing,
}, nil }, nil
} }

View File

@@ -0,0 +1,7 @@
package wireguard
import "net/netip"
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
}

View File

@@ -1,6 +1,8 @@
package wireguard package wireguard
import "github.com/qdm12/gluetun/internal/netlink" import (
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
@@ -15,6 +17,7 @@ type NetLinker interface {
type Router interface { type Router interface {
RouteList(family int) (routes []netlink.Route, err error) RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {

View File

@@ -1,6 +1,7 @@
package wireguard package wireguard
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strings" "strings"
@@ -29,6 +30,10 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil return nil
} }
var (
ErrDefaultRouteNotFound = errors.New("default route not found")
)
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark uint32, firewallMark uint32,
) (err error) { ) (err error) {
@@ -45,5 +50,39 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
link.Name, dst, firewallMark, err) link.Name, dst, firewallMark, err)
} }
vpnGatewayIP, err := w.routing.VPNLocalGatewayIP(link.Name)
if err != nil {
return fmt.Errorf("getting VPN gateway IP: %w", err)
}
routes, err := w.netlink.RouteList(netlink.FamilyV4)
if err != nil {
return fmt.Errorf("listing routes: %w", err)
}
var defaultRoute netlink.Route
var defaultRouteFound bool
for _, route = range routes {
if !route.Dst.IsValid() || route.Dst.Addr().IsUnspecified() {
defaultRoute = route
defaultRouteFound = true
break
}
}
if !defaultRouteFound {
return fmt.Errorf("%w: in %d routes", ErrDefaultRouteNotFound, len(routes))
}
// Equivalent replacement to:
// ip route replace default via <vpn-gateway> dev tun0
defaultRoute.Gw = vpnGatewayIP
defaultRoute.LinkIndex = link.Index
err = w.netlink.RouteReplace(defaultRoute)
if err != nil {
return fmt.Errorf("replacing default route: %w", err)
}
return err return err
} }