Wireguard support for Mullvad and Windscribe (#565)
- `internal/wireguard` client package with unit tests - Implementation works with kernel space or user space if unavailable - `WIREGUARD_PRIVATE_KEY` - `WIREGUARD_ADDRESS` - `WIREGUARD_PRESHARED_KEY` - `WIREGUARD_PORT` - `internal/netlink` package used by `internal/wireguard`
This commit is contained in:
25
internal/wireguard/address.go
Normal file
25
internal/wireguard/address.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addAddresses(link netlink.Link,
|
||||
addresses []*net.IPNet) (err error) {
|
||||
for _, ipNet := range addresses {
|
||||
address := &netlink.Addr{
|
||||
IPNet: ipNet,
|
||||
}
|
||||
|
||||
err = w.netlink.AddrAdd(link, address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding address %s to link %s",
|
||||
err, address, link.Attrs().Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
94
internal/wireguard/address_test.go
Normal file
94
internal/wireguard/address_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
|
||||
ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)}
|
||||
|
||||
newLink := func() netlink.Link {
|
||||
linkAttrs := netlink.NewLinkAttrs()
|
||||
linkAttrs.Name = "a_bridge"
|
||||
return &netlink.Bridge{
|
||||
LinkAttrs: linkAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
addrs []*net.IPNet
|
||||
expectedAddrs []*netlink.Addr
|
||||
addrAddErrs []error
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: newLink(),
|
||||
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
||||
expectedAddrs: []*netlink.Addr{
|
||||
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
|
||||
},
|
||||
addrAddErrs: []error{nil, nil},
|
||||
},
|
||||
"first add error": {
|
||||
link: newLink(),
|
||||
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
||||
expectedAddrs: []*netlink.Addr{
|
||||
{IPNet: ipNetOne},
|
||||
},
|
||||
addrAddErrs: []error{errDummy},
|
||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
||||
},
|
||||
"second add error": {
|
||||
link: newLink(),
|
||||
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
||||
expectedAddrs: []*netlink.Addr{
|
||||
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
|
||||
},
|
||||
addrAddErrs: []error{nil, errDummy},
|
||||
err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs))
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
for i := range testCase.expectedAddrs {
|
||||
netLinker.EXPECT().
|
||||
AddrAdd(testCase.link, testCase.expectedAddrs[i]).
|
||||
Return(testCase.addrAddErrs[i])
|
||||
}
|
||||
|
||||
err := wg.addAddresses(testCase.link, testCase.addrs)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
59
internal/wireguard/cleanup.go
Normal file
59
internal/wireguard/cleanup.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package wireguard
|
||||
|
||||
import "sort"
|
||||
|
||||
type closer struct {
|
||||
operation string
|
||||
step step
|
||||
close func() error
|
||||
closed bool
|
||||
}
|
||||
|
||||
type closers []closer
|
||||
|
||||
func (c *closers) add(operation string, step step,
|
||||
closeFunc func() error) {
|
||||
closer := closer{
|
||||
operation: operation,
|
||||
step: step,
|
||||
close: closeFunc,
|
||||
}
|
||||
*c = append(*c, closer)
|
||||
}
|
||||
|
||||
func (c *closers) cleanup(logger Logger) {
|
||||
closers := *c
|
||||
|
||||
sort.Slice(closers, func(i, j int) bool {
|
||||
return closers[i].step < closers[j].step
|
||||
})
|
||||
|
||||
for i, closer := range closers {
|
||||
if closer.closed {
|
||||
continue
|
||||
} else {
|
||||
closers[i].closed = true
|
||||
}
|
||||
logger.Debug(closer.operation + "...")
|
||||
err := closer.close()
|
||||
if err != nil {
|
||||
logger.Error("failed " + closer.operation + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type step int
|
||||
|
||||
const (
|
||||
// stepOne closes the wireguard controller client,
|
||||
// and removes the IP rule.
|
||||
stepOne step = iota
|
||||
// stepTwo closes the UAPI listener.
|
||||
stepTwo
|
||||
// stepThree closes the UAPI file.
|
||||
stepThree
|
||||
// stepFour closes the Wireguard device.
|
||||
stepFour
|
||||
// stepFive closes the bind connection and the TUN device file.
|
||||
stepFive
|
||||
)
|
||||
57
internal/wireguard/cleanup_test.go
Normal file
57
internal/wireguard/cleanup_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_closers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var ACloseCalled, BCloseCalled, CCloseCalled bool
|
||||
var (
|
||||
AErr error
|
||||
BErr = errors.New("B failed")
|
||||
CErr = errors.New("C failed")
|
||||
)
|
||||
|
||||
var closers closers
|
||||
closers.add("closing A", stepFive, func() error {
|
||||
ACloseCalled = true
|
||||
return AErr
|
||||
})
|
||||
|
||||
closers.add("closing B", stepThree, func() error {
|
||||
BCloseCalled = true
|
||||
return BErr
|
||||
})
|
||||
|
||||
closers.add("closing C", stepTwo, func() error {
|
||||
CCloseCalled = true
|
||||
return CErr
|
||||
})
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
prevCall := logger.EXPECT().Debug("closing C...")
|
||||
prevCall = logger.EXPECT().Error("failed closing C: C failed").After(prevCall)
|
||||
prevCall = logger.EXPECT().Debug("closing B...").After(prevCall)
|
||||
prevCall = logger.EXPECT().Error("failed closing B: B failed").After(prevCall)
|
||||
logger.EXPECT().Debug("closing A...").After(prevCall)
|
||||
|
||||
closers.cleanup(logger)
|
||||
|
||||
closers.cleanup(logger) // run twice should not close already closed
|
||||
|
||||
for _, closer := range closers {
|
||||
assert.True(t, closer.closed)
|
||||
}
|
||||
|
||||
assert.True(t, ACloseCalled)
|
||||
assert.True(t, BCloseCalled)
|
||||
assert.True(t, CCloseCalled)
|
||||
}
|
||||
86
internal/wireguard/config.go
Normal file
86
internal/wireguard/config.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var (
|
||||
errMakeConfig = errors.New("cannot make device configuration")
|
||||
errConfigureDevice = errors.New("cannot configure device")
|
||||
)
|
||||
|
||||
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||
deviceConfig, err := makeDeviceConfig(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errMakeConfig, err)
|
||||
}
|
||||
|
||||
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", errConfigureDevice, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
||||
privateKey, err := wgtypes.ParseKey(settings.PrivateKey)
|
||||
if err != nil {
|
||||
return config, ErrPrivateKeyInvalid
|
||||
}
|
||||
|
||||
publicKey, err := wgtypes.ParseKey(settings.PublicKey)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("%w: %s", ErrPublicKeyInvalid, settings.PublicKey)
|
||||
}
|
||||
|
||||
var preSharedKey *wgtypes.Key
|
||||
if settings.PreSharedKey != "" {
|
||||
preSharedKeyValue, err := wgtypes.ParseKey(settings.PreSharedKey)
|
||||
if err != nil {
|
||||
return config, ErrPreSharedKeyInvalid
|
||||
}
|
||||
preSharedKey = &preSharedKeyValue
|
||||
}
|
||||
|
||||
firewallMark := settings.FirewallMark
|
||||
|
||||
config = wgtypes.Config{
|
||||
PrivateKey: &privateKey,
|
||||
ReplacePeers: true,
|
||||
FirewallMark: &firewallMark,
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: publicKey,
|
||||
PresharedKey: preSharedKey,
|
||||
AllowedIPs: []net.IPNet{
|
||||
*allIPv4(),
|
||||
*allIPv6(),
|
||||
},
|
||||
ReplaceAllowedIPs: true,
|
||||
Endpoint: settings.Endpoint,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func allIPv4() (ipNet *net.IPNet) {
|
||||
return &net.IPNet{
|
||||
IP: net.IPv4(0, 0, 0, 0),
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
}
|
||||
}
|
||||
|
||||
func allIPv6() (ipNet *net.IPNet) {
|
||||
return &net.IPNet{
|
||||
IP: net.IPv6zero,
|
||||
Mask: []byte(net.IPv6zero),
|
||||
}
|
||||
}
|
||||
126
internal/wireguard/config_test.go
Normal file
126
internal/wireguard/config_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func Test_makeDeviceConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
|
||||
validKey3 = "gFIW0lTmBYEucynoIg+XmeWckDUXTcC4Po5ijR5G+HM="
|
||||
)
|
||||
|
||||
parseKey := func(t *testing.T, s string) *wgtypes.Key {
|
||||
t.Helper()
|
||||
key, err := wgtypes.ParseKey(s)
|
||||
require.NoError(t, err)
|
||||
return &key
|
||||
}
|
||||
|
||||
intPtr := func(n int) *int { return &n }
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
config wgtypes.Config
|
||||
err error
|
||||
}{
|
||||
"bad private key": {
|
||||
settings: Settings{
|
||||
PrivateKey: "bad key",
|
||||
},
|
||||
err: ErrPrivateKeyInvalid,
|
||||
},
|
||||
"bad public key": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse public key: bad key"),
|
||||
},
|
||||
"bad pre-shared key": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
PreSharedKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse pre-shared key"),
|
||||
},
|
||||
"valid settings": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
PreSharedKey: validKey3,
|
||||
FirewallMark: 9876,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(99, 99, 99, 99),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
config: wgtypes.Config{
|
||||
PrivateKey: parseKey(t, validKey1),
|
||||
ReplacePeers: true,
|
||||
FirewallMark: intPtr(9876),
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: *parseKey(t, validKey2),
|
||||
PresharedKey: parseKey(t, validKey3),
|
||||
AllowedIPs: []net.IPNet{
|
||||
{
|
||||
IP: net.IPv4(0, 0, 0, 0),
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
},
|
||||
{
|
||||
IP: net.IPv6zero,
|
||||
Mask: []byte(net.IPv6zero),
|
||||
},
|
||||
},
|
||||
ReplaceAllowedIPs: true,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(99, 99, 99, 99),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := makeDeviceConfig(testCase.settings)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.config, config)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_allIPv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
ipNet := allIPv4()
|
||||
assert.Equal(t, "0.0.0.0/0", ipNet.String())
|
||||
}
|
||||
|
||||
func Test_allIPv6(t *testing.T) {
|
||||
t.Parallel()
|
||||
ipNet := allIPv6()
|
||||
assert.Equal(t, "::/0", ipNet.String())
|
||||
}
|
||||
30
internal/wireguard/constructor.go
Normal file
30
internal/wireguard/constructor.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package wireguard
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/netlink"
|
||||
|
||||
var _ Wireguarder = (*Wireguard)(nil)
|
||||
|
||||
type Wireguarder interface {
|
||||
Runner
|
||||
Runner
|
||||
}
|
||||
|
||||
type Wireguard struct {
|
||||
logger Logger
|
||||
settings Settings
|
||||
netlink netlink.NetLinker
|
||||
}
|
||||
|
||||
func New(settings Settings, netlink NetLinker,
|
||||
logger Logger) (w *Wireguard, err error) {
|
||||
settings.SetDefaults()
|
||||
if err := settings.Check(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Wireguard{
|
||||
logger: logger,
|
||||
settings: settings,
|
||||
netlink: netlink,
|
||||
}, nil
|
||||
}
|
||||
80
internal/wireguard/constructor_test.go
Normal file
80
internal/wireguard/constructor_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const validKeyString = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
logger := NewMockLogger(nil)
|
||||
netLinker := NewMockNetLinker(nil)
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
wireguard *Wireguard
|
||||
err error
|
||||
}{
|
||||
"bad settings": {
|
||||
settings: Settings{
|
||||
PrivateKey: "",
|
||||
},
|
||||
err: ErrPrivateKeyMissing,
|
||||
},
|
||||
"minimal valid settings": {
|
||||
settings: Settings{
|
||||
PrivateKey: validKeyString,
|
||||
PublicKey: validKeyString,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
},
|
||||
Addresses: []*net.IPNet{{
|
||||
IP: net.IPv4(5, 6, 7, 8),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
},
|
||||
FirewallMark: 100,
|
||||
},
|
||||
wireguard: &Wireguard{
|
||||
logger: logger,
|
||||
netlink: netLinker,
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKeyString,
|
||||
PublicKey: validKeyString,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{
|
||||
IP: net.IPv4(5, 6, 7, 8),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
},
|
||||
FirewallMark: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wireguard, err := New(testCase.settings, netLinker, logger)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.wireguard, wireguard)
|
||||
})
|
||||
}
|
||||
}
|
||||
26
internal/wireguard/log.go
Normal file
26
internal/wireguard/log.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=log_mock_test.go -package wireguard . Logger
|
||||
|
||||
type Logger interface {
|
||||
Debug(s string)
|
||||
Info(s string)
|
||||
Error(s string)
|
||||
}
|
||||
|
||||
func makeDeviceLogger(logger Logger) (deviceLogger *device.Logger) {
|
||||
return &device.Logger{
|
||||
Verbosef: func(format string, args ...interface{}) {
|
||||
logger.Debug(fmt.Sprintf(format, args...))
|
||||
},
|
||||
Errorf: func(format string, args ...interface{}) {
|
||||
logger.Error(fmt.Sprintf(format, args...))
|
||||
},
|
||||
}
|
||||
}
|
||||
70
internal/wireguard/log_mock_test.go
Normal file
70
internal/wireguard/log_mock_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: Logger)
|
||||
|
||||
// Package wireguard is a generated GoMock package.
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLogger is a mock of Logger interface.
|
||||
type MockLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockLoggerMockRecorder is the mock recorder for MockLogger.
|
||||
type MockLoggerMockRecorder struct {
|
||||
mock *MockLogger
|
||||
}
|
||||
|
||||
// NewMockLogger creates a new mock instance.
|
||||
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
|
||||
mock := &MockLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debug mocks base method.
|
||||
func (m *MockLogger) Debug(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Debug", arg0)
|
||||
}
|
||||
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Error mocks base method.
|
||||
func (m *MockLogger) Error(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Error", arg0)
|
||||
}
|
||||
|
||||
// Error indicates an expected call of Error.
|
||||
func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0)
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockLogger) Info(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Info", arg0)
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0)
|
||||
}
|
||||
23
internal/wireguard/log_test.go
Normal file
23
internal/wireguard/log_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
func Test_makeDeviceLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
logger := NewMockLogger(ctrl)
|
||||
|
||||
deviceLogger := makeDeviceLogger(logger)
|
||||
|
||||
logger.EXPECT().Debug("test 1")
|
||||
deviceLogger.Verbosef("test %d", 1)
|
||||
|
||||
logger.EXPECT().Error("test 2")
|
||||
deviceLogger.Errorf("test %d", 2)
|
||||
}
|
||||
113
internal/wireguard/netlink_integration_test.go
Normal file
113
internal/wireguard/netlink_integration_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// +build netlink
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
inetlink "github.com/qdm12/gluetun/internal/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := inetlink.New()
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
}
|
||||
|
||||
intfName := "test_" + fmt.Sprint(rand.Intn(10000)) //nolint:gosec
|
||||
|
||||
// Add link
|
||||
linkAttrs := netlink.NewLinkAttrs()
|
||||
linkAttrs.Name = intfName
|
||||
link := &netlink.Bridge{
|
||||
LinkAttrs: linkAttrs,
|
||||
}
|
||||
err := netlink.LinkAdd(link)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err = netlink.LinkDel(link)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
addresses := []*net.IPNet{
|
||||
{IP: net.IP{1, 2, 3, 4}, Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
{IP: net.IP{5, 6, 7, 8}, Mask: net.IPv4Mask(255, 255, 255, 255)},
|
||||
}
|
||||
|
||||
// Success
|
||||
err = wg.addAddresses(link, addresses)
|
||||
require.NoError(t, err)
|
||||
|
||||
netlinkAddresses, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(addresses), len(netlinkAddresses))
|
||||
for i, netlinkAddress := range netlinkAddresses {
|
||||
ipNet := netlinkAddress.IPNet
|
||||
assert.Equal(t, addresses[i], ipNet)
|
||||
}
|
||||
|
||||
// Existing address cannot be added
|
||||
err = wg.addAddresses(link, addresses)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "file exists: when adding address 1.2.3.4/32 to link test_8081", err.Error())
|
||||
}
|
||||
|
||||
func Test_netlink_Wireguard_addRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
netlinker := inetlink.New()
|
||||
wg := &Wireguard{
|
||||
netlink: netlinker,
|
||||
}
|
||||
|
||||
rulePriority := 10000
|
||||
const firewallMark = 999
|
||||
|
||||
cleanup, err := wg.addRule(rulePriority, firewallMark)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := cleanup()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
rules, err := netlink.RuleList(netlink.FAMILY_ALL)
|
||||
require.NoError(t, err)
|
||||
var rule netlink.Rule
|
||||
var ruleFound bool
|
||||
for _, rule = range rules {
|
||||
if rule.Mark == firewallMark {
|
||||
ruleFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, ruleFound)
|
||||
expectedRule := netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: 4294967295,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
}
|
||||
assert.Equal(t, expectedRule, rule)
|
||||
|
||||
// Existing rule cannot be added
|
||||
nilCleanup, err := wg.addRule(rulePriority, firewallMark)
|
||||
if nilCleanup != nil {
|
||||
_ = nilCleanup() // in case it succeeds
|
||||
}
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "file exists: when adding rule: ip rule 10000: from <nil> table 999", err.Error())
|
||||
}
|
||||
12
internal/wireguard/netlinker.go
Normal file
12
internal/wireguard/netlinker.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package wireguard
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
|
||||
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
||||
|
||||
type NetLinker interface {
|
||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||
RouteAdd(route *netlink.Route) error
|
||||
RuleAdd(rule *netlink.Rule) error
|
||||
RuleDel(rule *netlink.Rule) error
|
||||
}
|
||||
91
internal/wireguard/netlinker_mock_test.go
Normal file
91
internal/wireguard/netlinker_mock_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/wireguard (interfaces: NetLinker)
|
||||
|
||||
// Package wireguard is a generated GoMock package.
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
netlink "github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// MockNetLinker is a mock of NetLinker interface.
|
||||
type MockNetLinker struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockNetLinkerMockRecorder
|
||||
}
|
||||
|
||||
// MockNetLinkerMockRecorder is the mock recorder for MockNetLinker.
|
||||
type MockNetLinkerMockRecorder struct {
|
||||
mock *MockNetLinker
|
||||
}
|
||||
|
||||
// NewMockNetLinker creates a new mock instance.
|
||||
func NewMockNetLinker(ctrl *gomock.Controller) *MockNetLinker {
|
||||
mock := &MockNetLinker{ctrl: ctrl}
|
||||
mock.recorder = &MockNetLinkerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddrAdd mocks base method.
|
||||
func (m *MockNetLinker) AddrAdd(arg0 netlink.Link, arg1 *netlink.Addr) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddrAdd", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddrAdd indicates an expected call of AddrAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1)
|
||||
}
|
||||
|
||||
// RouteAdd mocks base method.
|
||||
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RouteAdd indicates an expected call of RouteAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
|
||||
}
|
||||
|
||||
// RuleAdd mocks base method.
|
||||
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RuleAdd indicates an expected call of RuleAdd.
|
||||
func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleAdd", reflect.TypeOf((*MockNetLinker)(nil).RuleAdd), arg0)
|
||||
}
|
||||
|
||||
// RuleDel mocks base method.
|
||||
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RuleDel indicates an expected call of RuleDel.
|
||||
func (mr *MockNetLinkerMockRecorder) RuleDel(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuleDel", reflect.TypeOf((*MockNetLinker)(nil).RuleDel), arg0)
|
||||
}
|
||||
26
internal/wireguard/route.go
Normal file
26
internal/wireguard/route.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// TODO add IPv6 route if IPv6 is supported
|
||||
|
||||
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
|
||||
firewallMark int) (err error) {
|
||||
route := &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: dst,
|
||||
Table: firewallMark,
|
||||
}
|
||||
|
||||
err = w.netlink.RouteAdd(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when adding route: %s", err, route)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
85
internal/wireguard/route_test.go
Normal file
85
internal/wireguard/route_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const linkIndex = 88
|
||||
newLink := func() netlink.Link {
|
||||
linkAttrs := netlink.NewLinkAttrs()
|
||||
linkAttrs.Name = "a_bridge"
|
||||
linkAttrs.Index = linkIndex
|
||||
return &netlink.Bridge{
|
||||
LinkAttrs: linkAttrs,
|
||||
}
|
||||
}
|
||||
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
|
||||
const firewallMark = 51820
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
link netlink.Link
|
||||
dst *net.IPNet
|
||||
expectedRoute *netlink.Route
|
||||
routeAddErr error
|
||||
err error
|
||||
}{
|
||||
"success": {
|
||||
link: newLink(),
|
||||
dst: ipNet,
|
||||
expectedRoute: &netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipNet,
|
||||
Table: firewallMark,
|
||||
},
|
||||
},
|
||||
"route add error": {
|
||||
link: newLink(),
|
||||
dst: ipNet,
|
||||
expectedRoute: &netlink.Route{
|
||||
LinkIndex: linkIndex,
|
||||
Dst: ipNet,
|
||||
Table: firewallMark,
|
||||
},
|
||||
routeAddErr: errDummy,
|
||||
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820}"), //nolint:lll
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().
|
||||
RouteAdd(testCase.expectedRoute).
|
||||
Return(testCase.routeAddErr)
|
||||
|
||||
err := wg.addRoute(testCase.link, testCase.dst, firewallMark)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
28
internal/wireguard/rule.go
Normal file
28
internal/wireguard/rule.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func (w *Wireguard) addRule(rulePriority, firewallMark int) (
|
||||
cleanup func() error, err error) {
|
||||
rule := netlink.NewRule()
|
||||
rule.Invert = true
|
||||
rule.Priority = rulePriority
|
||||
rule.Mark = firewallMark
|
||||
rule.Table = firewallMark
|
||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||
return nil, fmt.Errorf("%w: when adding rule: %s", err, rule)
|
||||
}
|
||||
|
||||
cleanup = func() error {
|
||||
err := w.netlink.RuleDel(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: when deleting rule: %s", err, rule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return cleanup, nil
|
||||
}
|
||||
106
internal/wireguard/rule_test.go
Normal file
106
internal/wireguard/rule_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
func Test_Wireguard_addRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rulePriority = 987
|
||||
const firewallMark = 456
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
testCases := map[string]struct {
|
||||
expectedRule *netlink.Rule
|
||||
ruleAddErr error
|
||||
err error
|
||||
ruleDelErr error
|
||||
cleanupErr error
|
||||
}{
|
||||
"success": {
|
||||
expectedRule: &netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
},
|
||||
},
|
||||
"rule add error": {
|
||||
expectedRule: &netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
},
|
||||
ruleAddErr: errDummy,
|
||||
err: errors.New("dummy: when adding rule: ip rule 987: from <nil> table 456"),
|
||||
},
|
||||
"rule delete error": {
|
||||
expectedRule: &netlink.Rule{
|
||||
Invert: true,
|
||||
Priority: rulePriority,
|
||||
Mark: firewallMark,
|
||||
Table: firewallMark,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
},
|
||||
ruleDelErr: errDummy,
|
||||
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from <nil> table 456"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
netLinker := NewMockNetLinker(ctrl)
|
||||
wg := Wireguard{
|
||||
netlink: netLinker,
|
||||
}
|
||||
|
||||
netLinker.EXPECT().RuleAdd(testCase.expectedRule).
|
||||
Return(testCase.ruleAddErr)
|
||||
cleanup, err := wg.addRule(rulePriority, firewallMark)
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
netLinker.EXPECT().RuleDel(testCase.expectedRule).
|
||||
Return(testCase.ruleDelErr)
|
||||
err = cleanup()
|
||||
if testCase.cleanupErr != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.cleanupErr.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
165
internal/wireguard/run.go
Normal file
165
internal/wireguard/run.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCreateTun = errors.New("cannot create TUN device")
|
||||
ErrFindLink = errors.New("cannot find link")
|
||||
ErrFindDevice = errors.New("cannot find Wireguard device")
|
||||
ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
|
||||
ErrWgctrlOpen = errors.New("cannot open wgctrl")
|
||||
ErrUAPIListen = errors.New("cannot listen on UAPI socket")
|
||||
ErrAddAddress = errors.New("cannot add address to wireguard interface")
|
||||
ErrConfigure = errors.New("cannot configure wireguard interface")
|
||||
ErrIfaceUp = errors.New("cannot set the interface to UP")
|
||||
ErrRouteAdd = errors.New("cannot add route for interface")
|
||||
ErrRuleAdd = errors.New("cannot add rule for interface")
|
||||
ErrDeviceWaited = errors.New("device waited for")
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
Run(ctx context.Context, waitError chan<- error, ready chan<- struct{})
|
||||
}
|
||||
|
||||
// See https://git.zx2c4.com/wireguard-go/tree/main.go
|
||||
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
|
||||
return
|
||||
}
|
||||
|
||||
var closers closers
|
||||
closers.add("closing controller client", stepOne, client.Close)
|
||||
|
||||
defer closers.cleanup(w.logger)
|
||||
|
||||
tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("closing TUN device", stepFive, tun.Close)
|
||||
|
||||
tunName, err := tun.Name()
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||
return
|
||||
} else if tunName != w.settings.InterfaceName {
|
||||
waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||
ErrCreateTun, w.settings.InterfaceName, tunName)
|
||||
return
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(w.settings.InterfaceName)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err)
|
||||
return
|
||||
}
|
||||
|
||||
bind := conn.NewDefaultBind()
|
||||
|
||||
closers.add("closing bind", stepFive, bind.Close)
|
||||
|
||||
deviceLogger := makeDeviceLogger(w.logger)
|
||||
device := device.NewDevice(tun, bind, deviceLogger)
|
||||
|
||||
closers.add("closing Wireguard device", stepFour, func() error {
|
||||
device.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||
|
||||
uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||
return
|
||||
}
|
||||
|
||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||
|
||||
// acceptAndHandle exits when uapiListener is closed
|
||||
uapiAcceptErrorCh := make(chan error)
|
||||
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
|
||||
|
||||
err = w.addAddresses(link, w.settings.Addresses)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = configureDevice(client, w.settings)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrConfigure, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||
return
|
||||
}
|
||||
|
||||
ruleCleanup, err := w.addRule(
|
||||
w.settings.RulePriority, w.settings.FirewallMark)
|
||||
if err != nil {
|
||||
waitError <- fmt.Errorf("%w: %s", ErrRuleAdd, err)
|
||||
return
|
||||
}
|
||||
closers.add("removing rule", stepOne, ruleCleanup)
|
||||
|
||||
w.logger.Info("Wireguard is up")
|
||||
ready <- struct{}{}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case err = <-uapiAcceptErrorCh:
|
||||
close(uapiAcceptErrorCh)
|
||||
case <-device.Wait():
|
||||
err = ErrDeviceWaited
|
||||
}
|
||||
|
||||
closers.cleanup(w.logger)
|
||||
|
||||
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit
|
||||
|
||||
waitError <- err
|
||||
}
|
||||
|
||||
func acceptAndHandle(uapi net.Listener, device *device.Device,
|
||||
uapiAcceptErrorCh chan<- error) {
|
||||
for { // stopped by uapiFile.Close()
|
||||
conn, err := uapi.Accept()
|
||||
if err != nil {
|
||||
uapiAcceptErrorCh <- err
|
||||
return
|
||||
}
|
||||
go device.IpcHandle(conn)
|
||||
}
|
||||
}
|
||||
212
internal/wireguard/settings.go
Normal file
212
internal/wireguard/settings.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
// Interface name for the Wireguard interface.
|
||||
// It defaults to wg0 if unset.
|
||||
InterfaceName string
|
||||
// Private key in base 64 format
|
||||
PrivateKey string
|
||||
// Public key in base 64 format
|
||||
PublicKey string
|
||||
// Pre shared key in base 64 format
|
||||
PreSharedKey string
|
||||
// Wireguard server endpoint to connect to.
|
||||
Endpoint *net.UDPAddr
|
||||
// Addresses assigned to the client.
|
||||
Addresses []*net.IPNet
|
||||
// FirewallMark to be used in routing tables and IP rules.
|
||||
// It defaults to 51820 if left to 0.
|
||||
FirewallMark int
|
||||
// RulePriority is the priority for the rule created with the
|
||||
// FirewallMark.
|
||||
RulePriority int
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
if s.InterfaceName == "" {
|
||||
const defaultInterfaceName = "wg0"
|
||||
s.InterfaceName = defaultInterfaceName
|
||||
}
|
||||
|
||||
if s.Endpoint != nil && s.Endpoint.Port == 0 {
|
||||
const defaultPort = 51820
|
||||
s.Endpoint.Port = defaultPort
|
||||
}
|
||||
|
||||
if s.FirewallMark == 0 {
|
||||
const defaultFirewallMark = 51820
|
||||
s.FirewallMark = defaultFirewallMark
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInterfaceNameInvalid = errors.New("invalid interface name")
|
||||
ErrPrivateKeyMissing = errors.New("private key is missing")
|
||||
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
|
||||
ErrPublicKeyMissing = errors.New("public key is missing")
|
||||
ErrPublicKeyInvalid = errors.New("cannot parse public key")
|
||||
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
|
||||
ErrEndpointMissing = errors.New("endpoint is missing")
|
||||
ErrEndpointIPMissing = errors.New("endpoint IP is missing")
|
||||
ErrEndpointPortMissing = errors.New("endpoint port is missing")
|
||||
ErrAddressMissing = errors.New("interface address is missing")
|
||||
ErrAddressNil = errors.New("interface address is nil")
|
||||
ErrAddressIPMissing = errors.New("interface address IP is missing")
|
||||
ErrAddressMaskMissing = errors.New("interface address mask is missing")
|
||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||
)
|
||||
|
||||
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
|
||||
func (s *Settings) Check() (err error) {
|
||||
if !interfaceNameRegexp.MatchString(s.InterfaceName) {
|
||||
return fmt.Errorf("%w: %s", ErrInterfaceNameInvalid, s.InterfaceName)
|
||||
}
|
||||
|
||||
if s.PrivateKey == "" {
|
||||
return ErrPrivateKeyMissing
|
||||
} else if _, err := wgtypes.ParseKey(s.PrivateKey); err != nil {
|
||||
return ErrPrivateKeyInvalid
|
||||
}
|
||||
|
||||
if s.PublicKey == "" {
|
||||
return ErrPublicKeyMissing
|
||||
} else if _, err := wgtypes.ParseKey(s.PublicKey); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrPublicKeyInvalid, s.PublicKey)
|
||||
}
|
||||
|
||||
if s.PreSharedKey != "" {
|
||||
if _, err := wgtypes.ParseKey(s.PreSharedKey); err != nil {
|
||||
return ErrPreSharedKeyInvalid
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case s.Endpoint == nil:
|
||||
return ErrEndpointMissing
|
||||
case s.Endpoint.IP == nil:
|
||||
return ErrEndpointIPMissing
|
||||
case s.Endpoint.Port == 0:
|
||||
return ErrEndpointPortMissing
|
||||
}
|
||||
|
||||
if len(s.Addresses) == 0 {
|
||||
return ErrAddressMissing
|
||||
}
|
||||
for i, addr := range s.Addresses {
|
||||
switch {
|
||||
case addr == nil:
|
||||
return fmt.Errorf("%w: for address %d of %d",
|
||||
ErrAddressNil, i+1, len(s.Addresses))
|
||||
case addr.IP == nil:
|
||||
return fmt.Errorf("%w: for address %d of %d",
|
||||
ErrAddressIPMissing, i+1, len(s.Addresses))
|
||||
case addr.Mask == nil:
|
||||
return fmt.Errorf("%w: for address %d of %d",
|
||||
ErrAddressMaskMissing, i+1, len(s.Addresses))
|
||||
}
|
||||
}
|
||||
|
||||
if s.FirewallMark == 0 {
|
||||
return ErrFirewallMarkMissing
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Settings) String() string {
|
||||
lines := s.ToLines(ToLinesSettings{})
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
type ToLinesSettings struct {
|
||||
// Indent defaults to 4 spaces " ".
|
||||
Indent *string
|
||||
// FieldPrefix defaults to "├── ".
|
||||
FieldPrefix *string
|
||||
// LastFieldPrefix defaults to "└── ".
|
||||
LastFieldPrefix *string
|
||||
}
|
||||
|
||||
func (settings *ToLinesSettings) setDefaults() {
|
||||
toStringPtr := func(s string) *string { return &s }
|
||||
if settings.Indent == nil {
|
||||
settings.Indent = toStringPtr(" ")
|
||||
}
|
||||
if settings.FieldPrefix == nil {
|
||||
settings.FieldPrefix = toStringPtr("├── ")
|
||||
}
|
||||
if settings.LastFieldPrefix == nil {
|
||||
settings.LastFieldPrefix = toStringPtr("└── ")
|
||||
}
|
||||
}
|
||||
|
||||
// ToLines serializes the settings to a slice of strings for display.
|
||||
func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
|
||||
settings.setDefaults()
|
||||
|
||||
indent := *settings.Indent
|
||||
fieldPrefix := *settings.FieldPrefix
|
||||
lastFieldPrefix := *settings.LastFieldPrefix
|
||||
|
||||
lines = append(lines, fieldPrefix+"Interface name: "+s.InterfaceName)
|
||||
const (
|
||||
set = "set"
|
||||
notSet = "not set"
|
||||
)
|
||||
|
||||
isSet := notSet
|
||||
if s.PrivateKey != "" {
|
||||
isSet = set
|
||||
}
|
||||
lines = append(lines, fieldPrefix+"Private key: "+isSet)
|
||||
|
||||
if s.PublicKey != "" {
|
||||
lines = append(lines, fieldPrefix+"PublicKey: "+s.PublicKey)
|
||||
}
|
||||
|
||||
isSet = notSet
|
||||
if s.PreSharedKey != "" {
|
||||
isSet = set
|
||||
}
|
||||
lines = append(lines, fieldPrefix+"Pre shared key: "+isSet)
|
||||
|
||||
endpointStr := notSet
|
||||
if s.Endpoint != nil {
|
||||
endpointStr = s.Endpoint.String()
|
||||
}
|
||||
lines = append(lines, fieldPrefix+"Endpoint: "+endpointStr)
|
||||
|
||||
if s.FirewallMark != 0 {
|
||||
lines = append(lines, fieldPrefix+"Firewall mark: "+fmt.Sprint(s.FirewallMark))
|
||||
}
|
||||
|
||||
if s.RulePriority != 0 {
|
||||
lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority))
|
||||
}
|
||||
|
||||
if len(s.Addresses) == 0 {
|
||||
lines = append(lines, lastFieldPrefix+"Addresses: "+notSet)
|
||||
} else {
|
||||
lines = append(lines, lastFieldPrefix+"Addresses:")
|
||||
for i, address := range s.Addresses {
|
||||
prefix := fieldPrefix
|
||||
if i == len(s.Addresses)-1 {
|
||||
prefix = lastFieldPrefix
|
||||
}
|
||||
lines = append(lines, indent+prefix+address.String())
|
||||
}
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
377
internal/wireguard/settings_test.go
Normal file
377
internal/wireguard/settings_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Settings_SetDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
original Settings
|
||||
expected Settings
|
||||
}{
|
||||
"empty settings": {
|
||||
expected: Settings{
|
||||
InterfaceName: "wg0",
|
||||
FirewallMark: 51820,
|
||||
},
|
||||
},
|
||||
"default endpoint port": {
|
||||
original: Settings{
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
},
|
||||
},
|
||||
expected: Settings{
|
||||
InterfaceName: "wg0",
|
||||
FirewallMark: 51820,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
},
|
||||
"not empty settings": {
|
||||
original: Settings{
|
||||
InterfaceName: "wg1",
|
||||
FirewallMark: 999,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 9999,
|
||||
},
|
||||
},
|
||||
expected: Settings{
|
||||
InterfaceName: "wg1",
|
||||
FirewallMark: 999,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 9999,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCase.original.SetDefaults()
|
||||
|
||||
assert.Equal(t, testCase.expected, testCase.original)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Settings_Check(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
validKey1 = "oMNSf/zJ0pt1ciy+qIRk8Rlyfs9accwuRLnKd85Yl1Q="
|
||||
validKey2 = "aPjc9US5ICB30D1P4glR9tO7bkB2Ga+KZiFqnoypBHk="
|
||||
)
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
err error
|
||||
}{
|
||||
"empty settings": {
|
||||
err: errors.New("invalid interface name: "),
|
||||
},
|
||||
"bad interface name": {
|
||||
settings: Settings{
|
||||
InterfaceName: "$H1T",
|
||||
},
|
||||
err: errors.New("invalid interface name: $H1T"),
|
||||
},
|
||||
"empty private key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
},
|
||||
err: ErrPrivateKeyMissing,
|
||||
},
|
||||
"bad private key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: "bad key",
|
||||
},
|
||||
err: ErrPrivateKeyInvalid,
|
||||
},
|
||||
"empty public key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
},
|
||||
err: ErrPublicKeyMissing,
|
||||
},
|
||||
"bad public key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse public key: bad key"),
|
||||
},
|
||||
"bad preshared key": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
PreSharedKey: "bad key",
|
||||
},
|
||||
err: errors.New("cannot parse pre-shared key"),
|
||||
},
|
||||
"empty endpoint": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
},
|
||||
err: ErrEndpointMissing,
|
||||
},
|
||||
"nil endpoint IP": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{},
|
||||
},
|
||||
err: ErrEndpointIPMissing,
|
||||
},
|
||||
"nil endpoint port": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
},
|
||||
},
|
||||
err: ErrEndpointPortMissing,
|
||||
},
|
||||
"no address": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
err: ErrAddressMissing,
|
||||
},
|
||||
"nil address": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{nil},
|
||||
},
|
||||
err: errors.New("interface address is nil: for address 1 of 1"),
|
||||
},
|
||||
"nil address IP": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{}},
|
||||
},
|
||||
err: errors.New("interface address IP is missing: for address 1 of 1"),
|
||||
},
|
||||
"nil address mask": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4)}},
|
||||
},
|
||||
err: errors.New("interface address mask is missing: for address 1 of 1"),
|
||||
},
|
||||
"zero firewall mark": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
|
||||
},
|
||||
err: ErrFirewallMarkMissing,
|
||||
},
|
||||
"all valid": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: validKey1,
|
||||
PublicKey: validKey2,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
Addresses: []*net.IPNet{{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}},
|
||||
FirewallMark: 999,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := testCase.settings.Check()
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toStringPtr(s string) *string { return &s }
|
||||
|
||||
func Test_ToLinesSettings_setDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
settings := ToLinesSettings{
|
||||
Indent: toStringPtr("indent"),
|
||||
}
|
||||
|
||||
someFunc := func(settings ToLinesSettings) {
|
||||
settings.setDefaults()
|
||||
expectedSettings := ToLinesSettings{
|
||||
Indent: toStringPtr("indent"),
|
||||
FieldPrefix: toStringPtr("├── "),
|
||||
LastFieldPrefix: toStringPtr("└── "),
|
||||
}
|
||||
assert.Equal(t, expectedSettings, settings)
|
||||
}
|
||||
someFunc(settings)
|
||||
|
||||
untouchedSettings := ToLinesSettings{
|
||||
Indent: toStringPtr("indent"),
|
||||
}
|
||||
assert.Equal(t, untouchedSettings, settings)
|
||||
}
|
||||
|
||||
func Test_Settings_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
settings := Settings{
|
||||
InterfaceName: "wg0",
|
||||
}
|
||||
const expected = `├── Interface name: wg0
|
||||
├── Private key: not set
|
||||
├── Pre shared key: not set
|
||||
├── Endpoint: not set
|
||||
└── Addresses: not set`
|
||||
s := settings.String()
|
||||
assert.Equal(t, expected, s)
|
||||
}
|
||||
|
||||
func Test_Settings_Lines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
lineSettings ToLinesSettings
|
||||
lines []string
|
||||
}{
|
||||
"empty settings": {
|
||||
lines: []string{
|
||||
"├── Interface name: ",
|
||||
"├── Private key: not set",
|
||||
"├── Pre shared key: not set",
|
||||
"├── Endpoint: not set",
|
||||
"└── Addresses: not set",
|
||||
},
|
||||
},
|
||||
"settings all set": {
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
PrivateKey: "private key",
|
||||
PublicKey: "public key",
|
||||
PreSharedKey: "pre-shared key",
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: net.IPv4(1, 2, 3, 4),
|
||||
Port: 51820,
|
||||
},
|
||||
FirewallMark: 999,
|
||||
RulePriority: 888,
|
||||
Addresses: []*net.IPNet{
|
||||
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
lines: []string{
|
||||
"├── Interface name: wg0",
|
||||
"├── Private key: set",
|
||||
"├── PublicKey: public key",
|
||||
"├── Pre shared key: set",
|
||||
"├── Endpoint: 1.2.3.4:51820",
|
||||
"├── Firewall mark: 999",
|
||||
"├── Rule priority: 888",
|
||||
"└── Addresses:",
|
||||
" ├── 1.1.1.1/24",
|
||||
" └── 2.2.2.2/32",
|
||||
},
|
||||
},
|
||||
"custom line settings": {
|
||||
lineSettings: ToLinesSettings{
|
||||
Indent: toStringPtr(" "),
|
||||
FieldPrefix: toStringPtr("- "),
|
||||
LastFieldPrefix: toStringPtr("* "),
|
||||
},
|
||||
settings: Settings{
|
||||
InterfaceName: "wg0",
|
||||
Addresses: []*net.IPNet{
|
||||
{IP: net.IPv4(1, 1, 1, 1), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.IPv4(2, 2, 2, 2), Mask: net.CIDRMask(32, 32)},
|
||||
},
|
||||
},
|
||||
lines: []string{
|
||||
"- Interface name: wg0",
|
||||
"- Private key: not set",
|
||||
"- Pre shared key: not set",
|
||||
"- Endpoint: not set",
|
||||
"* Addresses:",
|
||||
" - 1.1.1.1/24",
|
||||
" * 2.2.2.2/32",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lines := testCase.settings.ToLines(testCase.lineSettings)
|
||||
|
||||
assert.Equal(t, testCase.lines, lines)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user