feat(wireguard): WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL option

This commit is contained in:
Quentin McGaw
2024-04-25 10:42:09 +00:00
parent 7b4befce61
commit c87c0e12fe
7 changed files with 65 additions and 15 deletions

View File

@@ -100,6 +100,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \ WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
WIREGUARD_PUBLIC_KEY= \ WIREGUARD_PUBLIC_KEY= \
WIREGUARD_ALLOWED_IPS= \ WIREGUARD_ALLOWED_IPS= \
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
WIREGUARD_ADDRESSES= \ WIREGUARD_ADDRESSES= \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \ WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU=1400 \ WIREGUARD_MTU=1400 \

View File

@@ -50,5 +50,6 @@ var (
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set") ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set") ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid") ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
ErrWireguardImplementationNotValid = errors.New("implementation is not valid") ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
) )

View File

@@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
@@ -34,7 +35,8 @@ type Wireguard struct {
// Interface is the name of the Wireguard interface // Interface is the name of the Wireguard interface
// to create. It cannot be the empty string in the // to create. It cannot be the empty string in the
// internal state. // internal state.
Interface string `json:"interface"` Interface string `json:"interface"`
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
// Maximum Transmission Unit (MTU) of the Wireguard interface. // Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be zero in the internal state, and defaults to // It cannot be zero in the internal state, and defaults to
// 1400. Note it is not the wireguard-go MTU default of 1420 // 1400. Note it is not the wireguard-go MTU default of 1420
@@ -123,6 +125,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
} }
} }
if *w.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
*w.PersistentKeepaliveInterval)
}
// Validate interface // Validate interface
if !regexpInterfaceName.MatchString(w.Interface) { if !regexpInterfaceName.MatchString(w.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'", return fmt.Errorf("%w: '%s' does not match regex '%s'",
@@ -139,13 +146,14 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
func (w *Wireguard) copy() (copied Wireguard) { func (w *Wireguard) copy() (copied Wireguard) {
return Wireguard{ return Wireguard{
PrivateKey: gosettings.CopyPointer(w.PrivateKey), PrivateKey: gosettings.CopyPointer(w.PrivateKey),
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey), PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
Addresses: gosettings.CopySlice(w.Addresses), Addresses: gosettings.CopySlice(w.Addresses),
AllowedIPs: gosettings.CopySlice(w.AllowedIPs), AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
Interface: w.Interface, PersistentKeepaliveInterval: gosettings.CopyPointer(w.PersistentKeepaliveInterval),
MTU: w.MTU, Interface: w.Interface,
Implementation: w.Implementation, MTU: w.MTU,
Implementation: w.Implementation,
} }
} }
@@ -154,6 +162,8 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey) w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses) w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs) w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
w.PersistentKeepaliveInterval = gosettings.OverrideWithPointer(w.PersistentKeepaliveInterval,
other.PersistentKeepaliveInterval)
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface) w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU) w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation) w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
@@ -172,6 +182,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.PrefixFrom(netip.IPv6Unspecified(), 0),
} }
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs) w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0") w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
const defaultMTU = 1400 const defaultMTU = 1400
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU) w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
@@ -205,6 +216,10 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
allowedIPsNode.Appendf(allowedIP.String()) allowedIPsNode.Appendf(allowedIP.String())
} }
if *w.PersistentKeepaliveInterval > 0 {
node.Appendf("Persistent keepalive interval: %s", w.PersistentKeepaliveInterval)
}
interfaceNode := node.Appendf("Network interface: %s", w.Interface) interfaceNode := node.Appendf("Network interface: %s", w.Interface)
interfaceNode.Appendf("MTU: %d", w.MTU) interfaceNode.Appendf("MTU: %d", w.MTU)
@@ -241,6 +256,12 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
if err != nil { if err != nil {
return err // already wrapped return err // already wrapped
} }
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
if err != nil {
return err
}
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU") mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
if err != nil { if err != nil {
return err return err

View File

@@ -40,5 +40,7 @@ func BuildWireguardSettings(connection models.Connection,
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP) settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
} }
settings.PersistentKeepaliveInterval = *userSettings.PersistentKeepaliveInterval
return settings return settings
} }

View File

@@ -3,6 +3,7 @@ package utils
import ( import (
"net/netip" "net/netip"
"testing" "testing"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
@@ -10,7 +11,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func stringPtr(s string) *string { return &s } func ptrTo[T any](x T) *T { return &x }
func Test_BuildWireguardSettings(t *testing.T) { func Test_BuildWireguardSettings(t *testing.T) {
t.Parallel() t.Parallel()
@@ -28,8 +29,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
PubKey: "public", PubKey: "public",
}, },
userSettings: settings.Wireguard{ userSettings: settings.Wireguard{
PrivateKey: stringPtr("private"), PrivateKey: ptrTo("private"),
PreSharedKey: stringPtr("pre-shared"), PreSharedKey: ptrTo("pre-shared"),
Addresses: []netip.Prefix{ Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
@@ -38,7 +39,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
}, },
Interface: "wg1", PersistentKeepaliveInterval: ptrTo(time.Hour),
Interface: "wg1",
}, },
ipv6Supported: false, ipv6Supported: false,
settings: wireguard.Settings{ settings: wireguard.Settings{
@@ -53,8 +55,9 @@ func Test_BuildWireguardSettings(t *testing.T) {
AllowedIPs: []netip.Prefix{ AllowedIPs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
}, },
RulePriority: 101, PersistentKeepaliveInterval: time.Hour,
IPv6: boolPtr(false), RulePriority: 101,
IPv6: boolPtr(false),
}, },
}, },
} }

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -43,6 +44,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
preSharedKey = &preSharedKeyValue preSharedKey = &preSharedKeyValue
} }
var persistentKeepaliveInterval *time.Duration
if settings.PersistentKeepaliveInterval > 0 {
persistentKeepaliveInterval = new(time.Duration)
*persistentKeepaliveInterval = settings.PersistentKeepaliveInterval
}
firewallMark := settings.FirewallMark firewallMark := settings.FirewallMark
config = wgtypes.Config{ config = wgtypes.Config{
@@ -63,7 +70,8 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
Mask: []byte(net.IPv6zero), Mask: []byte(net.IPv6zero),
}, },
}, },
ReplaceAllowedIPs: true, PersistentKeepaliveInterval: persistentKeepaliveInterval,
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{ Endpoint: &net.UDPAddr{
IP: settings.Endpoint.Addr().AsSlice(), IP: settings.Endpoint.Addr().AsSlice(),
Port: int(settings.Endpoint.Port()), Port: int(settings.Endpoint.Port()),

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"regexp" "regexp"
"strings" "strings"
"time"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -30,6 +31,8 @@ type Settings struct {
// the Wireguard interface. // the Wireguard interface.
// Note IPv6 addresses are ignored if IPv6 is not supported. // Note IPv6 addresses are ignored if IPv6 is not supported.
AllowedIPs []netip.Prefix AllowedIPs []netip.Prefix
// PersistentKeepaliveInterval defines the keep alive interval, if not zero.
PersistentKeepaliveInterval time.Duration
// FirewallMark to be used in routing tables and IP rules. // FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0. // It defaults to 51820 if left to 0.
FirewallMark int FirewallMark int
@@ -99,6 +102,7 @@ var (
ErrAllowedIPsMissing = errors.New("allowed IPs are missing") ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
ErrAllowedIPNotValid = errors.New("allowed IP is not valid") ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported") ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
ErrKeepaliveIsNegative = errors.New("keep alive interval is negative")
ErrFirewallMarkMissing = errors.New("firewall mark is missing") ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing") ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation") ErrImplementationInvalid = errors.New("invalid implementation")
@@ -160,6 +164,11 @@ func (s *Settings) Check() (err error) {
} }
} }
if s.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrKeepaliveIsNegative,
s.PersistentKeepaliveInterval)
}
if s.FirewallMark == 0 { if s.FirewallMark == 0 {
return fmt.Errorf("%w", ErrFirewallMarkMissing) return fmt.Errorf("%w", ErrFirewallMarkMissing)
} }
@@ -286,5 +295,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
} }
} }
if s.PersistentKeepaliveInterval > 0 {
lines = append(lines, fieldPrefix+"Persistent keep alive interval: "+
s.PersistentKeepaliveInterval.String())
}
return lines return lines
} }