Compare commits

..

1 Commits

Author SHA1 Message Date
Quentin McGaw
c3eca4a17c wip 2024-11-08 17:25:12 +00:00
16 changed files with 62 additions and 840 deletions

View File

@@ -32,6 +32,11 @@ type Route struct {
Type int
}
func (r Route) String() string {
return fmt.Sprintf("{link %d, dst %s, src %s, gw %s, priority %d, family %d, table %d, type %d}",
r.LinkIndex, r.Dst, r.Src, r.Gw, r.Priority, r.Family, r.Table, r.Type)
}
type Rule struct {
Priority int
Family int

View File

@@ -1,86 +0,0 @@
package socks5
import "fmt"
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
type authMethod byte
const (
authNotRequired authMethod = 0
authGssapi authMethod = 1
authUsernamePassword authMethod = 2
authNotAcceptable authMethod = 255
)
func (a authMethod) String() string {
switch a {
case authNotRequired:
return "no authentication required"
case authGssapi:
return "GSSAPI"
case authUsernamePassword:
return "username/password"
case authNotAcceptable:
return "no acceptable methods"
default:
return fmt.Sprintf("unknown method (%d)", a)
}
}
// Subnegotiation version
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
const (
authUsernamePasswordSubNegotiation1 byte = 1
)
// SOCKS versions.
const (
socks5Version byte = 5
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
type cmdType byte
const (
connect cmdType = 1
bind cmdType = 2
udpAssociate cmdType = 3
)
func (c cmdType) String() string {
switch c {
case connect:
return "connect"
case bind:
return "bind"
case udpAssociate:
return "UDP associate"
default:
return fmt.Sprintf("unknown command (%d)", c)
}
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 and
// https://datatracker.ietf.org/doc/html/rfc1928#section-5
type addrType byte
const (
ipv4 addrType = 1
domainName addrType = 3
ipv6 addrType = 4
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
type replyCode byte
const (
succeeded replyCode = iota
generalServerFailure
connectionNotAllowedByRuleset
networkUnreachable
hostUnreachable
connectionRefused
ttlExpired
commandNotSupported
addressTypeNotSupported
)

View File

@@ -1,6 +0,0 @@
package socks5
type Logger interface {
Infof(format string, a ...interface{})
Warnf(format string, a ...interface{})
}

View File

@@ -1,103 +0,0 @@
package socks5
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) {
_, err := writer.Write([]byte{
socksVersion,
byte(reply),
0, // RSV byte
// TODO do we need to set the bind addr type to 0??
})
if err != nil {
c.logger.Warnf("failed writing failed response: %s", err)
}
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
func (c *socksConn) encodeSuccessResponse(writer io.Writer, socksVersion byte,
reply replyCode, bindAddrType addrType, bindAddress string,
bindPort uint16) (err error) {
bindData, err := encodeBindData(bindAddrType, bindAddress, bindPort)
if err != nil { // TODO encode with below block if this changes
return err
}
const initialPacketLength = 3
capacity := initialPacketLength + len(bindData)
packet := make([]byte, initialPacketLength, capacity)
packet[0] = socksVersion
packet[1] = byte(reply)
packet[2] = 0 // RSV byte
packet = append(packet, bindData...)
_, err = writer.Write(packet)
if err != nil {
c.logger.Warnf("failed writing success response: %s", err)
}
return nil
}
var (
ErrIPVersionUnexpected = errors.New("ip version is unexpected")
ErrDomainNameTooLong = errors.New("domain name is too long")
)
func encodeBindData(addrType addrType, address string, port uint16) (
data []byte, err error) {
capacity := bindDataLength(addrType, address)
data = make([]byte, 0, capacity)
data = append(data, byte(addrType))
switch addrType {
case ipv4, ipv6:
ip, err := netip.ParseAddr(address)
if err != nil {
return nil, fmt.Errorf("parsing IP address: %w", err)
}
switch {
case addrType == ipv4 && !ip.Is4():
return nil, fmt.Errorf("%w: expected IPv4 for %s", ErrIPVersionUnexpected, ip)
case addrType == ipv6 && !ip.Is6():
return nil, fmt.Errorf("%w: expected IPv6 for %s", ErrIPVersionUnexpected, ip)
}
data = append(data, ip.AsSlice()...)
case domainName:
const maxDomainNameLength = 255
if len(address) > maxDomainNameLength {
return nil, fmt.Errorf("%w: %s", ErrDomainNameTooLong, address)
}
data = append(data, byte(len(address)))
data = append(data, []byte(address)...)
default:
panic(fmt.Sprintf("unsupported address type %d", addrType))
}
data = binary.BigEndian.AppendUint16(data, port)
return data, nil
}
func bindDataLength(addrType addrType, address string) (maxLength int) {
maxLength++ // address type
switch addrType {
case ipv4:
maxLength += net.IPv4len
case domainName:
maxLength++ // domain name length
maxLength += len([]byte(address))
case ipv6:
maxLength += net.IPv6len
default:
panic("unsupported address type: " + fmt.Sprint(addrType))
}
maxLength += 2 // port
return maxLength
}

View File

@@ -1,105 +0,0 @@
package socks5
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
)
type Server struct {
username string
password string
address string
logger Logger
// internal fields
listener net.Listener
listening atomic.Bool
socksConnCtx context.Context //nolint:containedctx
socksConnCancel context.CancelFunc
done <-chan struct{}
stopping atomic.Bool
}
func New(settings Settings) *Server {
return &Server{
username: settings.Username,
password: settings.Password,
address: settings.Address,
logger: settings.Logger,
}
}
func (s *Server) Start(_ context.Context) (runErr <-chan error, err error) {
s.listener, err = net.Listen("tcp", s.address)
if err != nil {
return nil, fmt.Errorf("listening on %s: %w", s.address, err)
}
s.listening.Store(true)
s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background())
ready := make(chan struct{})
runErrCh := make(chan error)
runErr = runErrCh
done := make(chan struct{})
s.done = done
go s.runServer(ready, runErrCh, done)
<-ready
return runErr, nil
}
func (s *Server) runServer(ready chan<- struct{},
runErrCh chan<- error, done chan<- struct{}) {
close(ready)
defer close(done)
wg := new(sync.WaitGroup)
defer wg.Wait()
dialer := &net.Dialer{}
for {
connection, err := s.listener.Accept()
if err != nil {
if !s.stopping.Load() {
_ = s.Stop()
runErrCh <- fmt.Errorf("accepting connection: %w", err)
}
return
}
wg.Add(1)
go func(ctx context.Context, connection net.Conn,
dialer *net.Dialer, wg *sync.WaitGroup) {
defer wg.Done()
socksConn := &socksConn{
dialer: dialer,
username: s.username,
password: s.password,
clientConn: connection,
logger: s.logger,
}
err := socksConn.run(ctx)
if err != nil {
s.logger.Infof("running socks connection: %s", err)
}
}(s.socksConnCtx, connection, dialer, wg)
}
}
func (s *Server) Stop() (err error) {
s.stopping.Store(true)
s.listening.Store(false)
err = s.listener.Close()
s.socksConnCancel() // stop ongoing socks connections
<-s.done // wait for run goroutine to finish
s.stopping.Store(false)
return err
}
func (s *Server) listeningAddress() net.Addr {
if s.listening.Load() {
return s.listener.Addr()
}
return nil
}

View File

@@ -1,8 +0,0 @@
package socks5
type Settings struct {
Username string
Password string
Address string
Logger Logger
}

View File

@@ -1,283 +0,0 @@
package socks5
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
)
type socksConn struct {
// Injected fields
dialer *net.Dialer
username string
password string
clientConn net.Conn
logger Logger
}
func (c *socksConn) closeClientConn(ctxErr error) {
err := c.clientConn.Close()
if err != nil && ctxErr == nil {
c.logger.Warnf("closing client connection: %s", err)
}
}
func (c *socksConn) run(ctx context.Context) error {
authMethod := authNotRequired
if c.username != "" || c.password != "" {
authMethod = authUsernamePassword
}
err := verifyFirstNegotiation(c.clientConn, authMethod)
if err != nil {
replyMethod := authMethod
if errors.Is(err, ErrNoMethodIdentifiers) || errors.Is(err, ErrNoValidMethodIdentifier) {
replyMethod = authNotAcceptable
}
_, writeErr := c.clientConn.Write([]byte{socks5Version, byte(replyMethod)})
if writeErr != nil {
c.logger.Warnf("failed writing first negotiation reply: %s", writeErr)
}
c.closeClientConn(ctx.Err())
return fmt.Errorf("verifying first negotiation: %w", err)
}
_, err = c.clientConn.Write([]byte{socks5Version, byte(authMethod)})
if err != nil {
c.closeClientConn(ctx.Err())
return fmt.Errorf("writing first negotiation reply: %w", err)
}
switch authMethod {
case authNotRequired, authNotAcceptable:
case authGssapi:
panic("not implemented")
// TODO
case authUsernamePassword:
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
err = usernamePasswordSubnegotiate(c.clientConn, c.username, c.password)
if err != nil {
// If the server returns a `failure' (STATUS value other than X'00') status,
// it MUST close the connection.
c.closeClientConn(ctx.Err())
return fmt.Errorf("subnegotiating username and password: %w", err)
}
default:
panic(fmt.Sprintf("unimplemented auth method %d", authMethod))
}
err = c.handleRequest(ctx)
c.closeClientConn(ctx.Err())
if err != nil {
return fmt.Errorf("handling request: %w", err)
}
return nil
}
var (
ErrCommandNotSupported = errors.New("command not supported")
)
func (c *socksConn) handleRequest(ctx context.Context) error {
const socksVersion = socks5Version
request, err := decodeRequest(c.clientConn, socksVersion)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return err
}
if request.command != connect {
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
return fmt.Errorf("%w: %s", ErrCommandNotSupported, request.command)
}
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return err
}
defer destinationConn.Close()
destinationServerAddress := destinationConn.LocalAddr().String()
destinationAddr, destinationPortStr, err := net.SplitHostPort(destinationServerAddress)
fmt.Println("===", destinationServerAddress)
if err != nil {
return err
}
destinationPort, err := strconv.Atoi(destinationPortStr)
if err != nil {
return err
}
var bindAddrType addrType
if ip := net.ParseIP(destinationAddr); ip != nil {
if ip.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
} else {
bindAddrType = domainName
}
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, bindAddrType,
destinationAddr, uint16(destinationPort))
if err != nil {
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
return fmt.Errorf("writing successful %s response: %w", request.command, err)
}
errc := make(chan error)
go func() {
_, err := io.Copy(c.clientConn, destinationConn)
if err != nil {
err = fmt.Errorf("from backend to client: %w", err)
}
errc <- err
}()
go func() {
_, err := io.Copy(destinationConn, c.clientConn)
if err != nil {
err = fmt.Errorf("from client to backend: %w", err)
}
errc <- err
}()
select {
case err := <-errc:
return err
case <-ctx.Done():
_ = destinationConn.Close()
_ = c.clientConn.Close()
return nil
}
}
var (
ErrVersionNotSupported = errors.New("version not supported")
ErrNoMethodIdentifiers = errors.New("no method identifiers")
)
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
const headerLength = 2 // version + nMethods bytes
header := make([]byte, headerLength)
_, err := io.ReadFull(reader, header[:])
if err != nil {
return fmt.Errorf("reading header: %w", err)
}
if header[0] != socks5Version {
return fmt.Errorf("%w: %d", ErrVersionNotSupported, header[0])
}
nMethods := header[1]
if nMethods == 0 {
return fmt.Errorf("%w", ErrNoMethodIdentifiers)
}
methodIdentifiers := make([]byte, nMethods)
_, err = io.ReadFull(reader, methodIdentifiers)
if err != nil {
return fmt.Errorf("reading method identifiers: %w", err)
}
for _, methodIdentifier := range methodIdentifiers {
if methodIdentifier == byte(requiredMethod) {
return nil
}
}
return makeNoAcceptableMethodError(requiredMethod, methodIdentifiers)
}
var (
ErrNoValidMethodIdentifier = errors.New("no valid method identifier")
)
func makeNoAcceptableMethodError(requiredAuthMethod authMethod, methodIdentifiers []byte) error {
methodNames := make([]string, len(methodIdentifiers))
for i, methodIdentifier := range methodIdentifiers {
methodNames[i] = fmt.Sprintf("%q", authMethod(methodIdentifier))
}
return fmt.Errorf("%w: none of %s matches %s",
ErrNoValidMethodIdentifier, strings.Join(methodNames, ", "),
requiredAuthMethod)
}
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
type request struct {
command cmdType
destination string
port uint16
addressType addrType
}
var (
ErrRequestSocksVersionMismatch = errors.New("request SOCKS version mismatch")
ErrAddressTypeNotSupported = errors.New("address type not supported")
)
func decodeRequest(reader io.Reader, expectedVersion byte) (req request, err error) {
const headerLength = 4
header := [headerLength]byte{}
_, err = io.ReadFull(reader, header[:])
if err != nil {
return request{}, fmt.Errorf("reading header: %w", err)
}
version := header[0]
if header[0] != expectedVersion {
return request{}, fmt.Errorf("%w: expected %d and got %d",
ErrRequestSocksVersionMismatch, expectedVersion, version)
}
req.command = cmdType(header[1])
// header[2] is RSV byte
req.addressType = addrType(header[3])
switch req.addressType {
case ipv4:
var ip [4]byte
_, err = io.ReadFull(reader, ip[:])
if err != nil {
return request{}, fmt.Errorf("reading IPv4 address: %w", err)
}
req.destination = netip.AddrFrom4(ip).String()
case ipv6:
var ip [16]byte
_, err = io.ReadFull(reader, ip[:])
if err != nil {
return request{}, fmt.Errorf("reading IPv6 address: %w", err)
}
req.destination = netip.AddrFrom16(ip).String()
case domainName:
var header [1]byte
_, err = io.ReadFull(reader, header[:])
if err != nil {
return request{}, fmt.Errorf("reading domain name header: %w", err)
}
domainName := make([]byte, header[0])
_, err = io.ReadFull(reader, domainName)
if err != nil {
return request{}, fmt.Errorf("reading domain name bytes: %w", err)
}
req.destination = string(domainName)
default:
return request{}, fmt.Errorf("%w: %d", ErrAddressTypeNotSupported, req.addressType)
}
var portBytes [2]byte
_, err = io.ReadFull(reader, portBytes[:])
if err != nil {
return request{}, fmt.Errorf("reading port: %w", err)
}
req.port = binary.BigEndian.Uint16(portBytes[:])
return req, nil
}

View File

@@ -1,175 +0,0 @@
package socks5
import (
"context"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/qdm12/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/proxy"
)
func Test(t *testing.T) {
server := New(Settings{
Username: "test",
Password: "test",
Address: ":8000",
Logger: log.New(),
})
runErr, startErr := server.Start(context.Background())
require.NoError(t, startErr)
select {
case err := <-runErr:
require.NoError(t, err)
default:
}
t.Log("SlEEPING")
time.Sleep(15 * time.Second)
t.Log("Done sleeping")
err := server.Stop()
require.NoError(t, err)
}
func backendServer(listener net.Listener) {
conn, err := listener.Accept()
if err != nil {
panic(err)
}
conn.Write([]byte("Test"))
conn.Close()
listener.Close()
}
func TestRead(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := listener.Addr().(*net.TCPAddr).Port
go backendServer(listener)
// SOCKS5 server
server := New(Settings{
Address: ":0",
})
_, err = server.Start(context.Background())
require.NoError(t, err)
t.Cleanup(func() {
err = server.Stop()
assert.NoError(t, err)
})
socks5Port := server.listeningAddress().(*net.TCPAddr).Port
addr := fmt.Sprintf("localhost:%d", socks5Port)
socksDialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct)
if err != nil {
t.Fatal(err)
}
addr = fmt.Sprintf("localhost:%d", backendServerPort)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 4)
_, err = io.ReadFull(conn, buf)
if err != nil {
t.Fatal(err)
}
if string(buf) != "Test" {
t.Fatalf("got: %q want: Test", buf)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}
func TestReadPassword(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := ln.Addr().(*net.TCPAddr).Port
go backendServer(ln)
auth := &proxy.Auth{User: "foo", Password: "bar"}
server := Server{
logger: log.New(),
username: auth.User,
password: auth.Password,
address: ":0",
}
_, err = server.Start(context.Background())
require.NoError(t, err)
t.Cleanup(func() {
err = server.Stop()
assert.NoError(t, err)
})
addr := fmt.Sprintf("localhost:%d", server.listeningAddress().(*net.TCPAddr).Port)
if d, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct); err != nil {
t.Fatal(err)
} else {
if _, err := d.Dial("tcp", addr); err == nil {
t.Fatal("expected no-auth dial error")
}
}
badPwd := &proxy.Auth{User: "foo", Password: "not right"}
if d, err := proxy.SOCKS5("tcp", addr, badPwd, proxy.Direct); err != nil {
t.Fatal(err)
} else {
if _, err := d.Dial("tcp", addr); err == nil {
t.Fatal("expected bad password dial error")
}
}
badUsr := &proxy.Auth{User: "not right", Password: "bar"}
if d, err := proxy.SOCKS5("tcp", addr, badUsr, proxy.Direct); err != nil {
t.Fatal(err)
} else {
if _, err := d.Dial("tcp", addr); err == nil {
t.Fatal("expected bad username dial error")
}
}
socksDialer, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct)
if err != nil {
t.Fatal(err)
}
addr = fmt.Sprintf("localhost:%d", backendServerPort)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
t.Fatal(err)
}
if string(buf) != "Test" {
t.Fatalf("got: %q want: Test", buf)
}
if err := conn.Close(); err != nil {
t.Fatal(err)
}
}

View File

@@ -1,69 +0,0 @@
package socks5
import (
"errors"
"fmt"
"io"
)
var (
ErrSubnegotiationVersionNotSupported = errors.New("subnegotiation version not supported")
ErrUsernameNotValid = errors.New("username not valid")
ErrPasswordNotValid = errors.New("password not valid")
)
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
func usernamePasswordSubnegotiate(conn io.ReadWriter, username, password string) (err error) {
status := byte(1)
const defaultVersion = byte(1)
const headerLength = 2
var header [headerLength]byte
_, err = io.ReadFull(conn, header[:])
if err != nil {
_, _ = conn.Write([]byte{defaultVersion, status})
return fmt.Errorf("reading header: %w", err)
}
if header[0] != authUsernamePasswordSubNegotiation1 {
_, _ = conn.Write([]byte{defaultVersion, status})
return fmt.Errorf("%w: %d", ErrSubnegotiationVersionNotSupported, header[0])
}
version := header[0]
usernameBytes := make([]byte, header[1])
_, err = io.ReadFull(conn, usernameBytes)
if err != nil {
_, _ = conn.Write([]byte{version, status})
return fmt.Errorf("reading username bytes: %w", err)
} else if username != string(usernameBytes) {
_, _ = conn.Write([]byte{version, status})
return fmt.Errorf("%w: %s", ErrUsernameNotValid, string(usernameBytes))
}
const passwordHeaderLength = 1
passwordHeader := make([]byte, passwordHeaderLength)
_, err = io.ReadFull(conn, passwordHeader[:])
if err != nil {
_, _ = conn.Write([]byte{version, status})
return fmt.Errorf("reading password length: %w", err)
}
passwordBytes := make([]byte, passwordHeader[0])
_, err = io.ReadFull(conn, passwordBytes)
if err != nil {
_, _ = conn.Write([]byte{version, status})
return fmt.Errorf("reading password bytes: %w", err)
} else if password != string(passwordBytes) {
_, _ = conn.Write([]byte{version, status})
return fmt.Errorf("%w: %s", ErrPasswordNotValid, string(passwordBytes))
}
status = 0
_, err = conn.Write([]byte{version, status})
if err != nil {
return fmt.Errorf("writing success status: %w", err)
}
return nil
}

View File

@@ -67,6 +67,7 @@ type NetLinker interface {
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {

View File

@@ -38,7 +38,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.routing, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
}
if err != nil {

View File

@@ -13,7 +13,7 @@ import (
// setupWireguard sets Wireguard up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker NetLinker,
func setupWireguard(ctx context.Context, netlinker NetLinker, routing Routing,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
@@ -29,7 +29,7 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey))
logger.Debug("Wireguard pre-shared key: " + gosettings.ObfuscateKey(wireguardSettings.PreSharedKey))
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
wireguarder, err = wireguard.New(wireguardSettings, netlinker, routing, logger)
if err != nil {
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
}

View File

@@ -4,10 +4,11 @@ type Wireguard struct {
logger Logger
settings Settings
netlink NetLinker
routing Routing
}
func New(settings Settings, netlink NetLinker,
logger Logger,
routing Routing, logger Logger,
) (w *Wireguard, err error) {
settings.SetDefaults()
if err := settings.Check(); err != nil {
@@ -18,5 +19,6 @@ func New(settings Settings, netlink NetLinker,
logger: logger,
settings: settings,
netlink: netlink,
routing: routing,
}, nil
}

View File

@@ -0,0 +1,7 @@
package wireguard
import "net/netip"
type Routing interface {
VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error)
}

View File

@@ -1,6 +1,8 @@
package wireguard
import "github.com/qdm12/gluetun/internal/netlink"
import (
"github.com/qdm12/gluetun/internal/netlink"
)
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
@@ -15,6 +17,7 @@ type NetLinker interface {
type Router interface {
RouteList(family int) (routes []netlink.Route, err error)
RouteAdd(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {

View File

@@ -1,6 +1,7 @@
package wireguard
import (
"errors"
"fmt"
"net/netip"
"strings"
@@ -29,6 +30,10 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
return nil
}
var (
ErrDefaultRouteNotFound = errors.New("default route not found")
)
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark uint32,
) (err error) {
@@ -45,5 +50,39 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
link.Name, dst, firewallMark, err)
}
vpnGatewayIP, err := w.routing.VPNLocalGatewayIP(link.Name)
if err != nil {
return fmt.Errorf("getting VPN gateway IP: %w", err)
}
routes, err := w.netlink.RouteList(netlink.FamilyV4)
if err != nil {
return fmt.Errorf("listing routes: %w", err)
}
var defaultRoute netlink.Route
var defaultRouteFound bool
for _, route = range routes {
if !route.Dst.IsValid() || route.Dst.Addr().IsUnspecified() {
defaultRoute = route
defaultRouteFound = true
break
}
}
if !defaultRouteFound {
return fmt.Errorf("%w: in %d routes", ErrDefaultRouteNotFound, len(routes))
}
// Equivalent replacement to:
// ip route replace default via <vpn-gateway> dev tun0
defaultRoute.Gw = vpnGatewayIP
defaultRoute.LinkIndex = link.Index
err = w.netlink.RouteReplace(defaultRoute)
if err != nil {
return fmt.Errorf("replacing default route: %w", err)
}
return err
}