This commit is contained in:
Quentin McGaw
2023-06-28 14:42:18 +00:00
parent 92011205be
commit c3eca4a17c
8 changed files with 62 additions and 5 deletions

View File

@@ -4,10 +4,11 @@ type Wireguard struct {
logger Logger
settings Settings
netlink NetLinker
routing Routing
}
func New(settings Settings, netlink NetLinker,
logger Logger,
routing Routing, logger Logger,
) (w *Wireguard, err error) {
settings.SetDefaults()
if err := settings.Check(); err != nil {
@@ -18,5 +19,6 @@ func New(settings Settings, netlink NetLinker,
logger: logger,
settings: settings,
netlink: netlink,
routing: routing,
}, nil
}

View File

@@ -0,0 +1,7 @@
package wireguard
import "net/netip"
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
}

View File

@@ -1,6 +1,8 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
import (
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
@@ -15,6 +17,7 @@ type NetLinker interface {
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {

View File

@@ -1,6 +1,7 @@
package wireguard
import (
"errors"
"fmt"
"net/netip"
"strings"
@@ -29,6 +30,10 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil
}
var (
ErrDefaultRouteNotFound = errors.New("default route not found")
)
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark uint32,
) (err error) {
@@ -45,5 +50,39 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
link.Name, dst, firewallMark, err)
}
vpnGatewayIP, err := w.routing.VPNLocalGatewayIP(link.Name)
if err != nil {
return fmt.Errorf("getting VPN gateway IP: %w", err)
}
routes, err := w.netlink.RouteList(netlink.FamilyV4)
if err != nil {
return fmt.Errorf("listing routes: %w", err)
}
var defaultRoute netlink.Route
var defaultRouteFound bool
for _, route = range routes {
if !route.Dst.IsValid() || route.Dst.Addr().IsUnspecified() {
defaultRoute = route
defaultRouteFound = true
break
}
}
if !defaultRouteFound {
return fmt.Errorf("%w: in %d routes", ErrDefaultRouteNotFound, len(routes))
}
// Equivalent replacement to:
// ip route replace default via <vpn-gateway> dev tun0
defaultRoute.Gw = vpnGatewayIP
defaultRoute.LinkIndex = link.Index
err = w.netlink.RouteReplace(defaultRoute)
if err != nil {
return fmt.Errorf("replacing default route: %w", err)
}
return err
}