feat(wireguard): WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL option

This commit is contained in:
Quentin McGaw
2024-04-25 10:42:09 +00:00
parent 7b4befce61
commit c87c0e12fe
7 changed files with 65 additions and 15 deletions

View File

@@ -100,6 +100,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
WIREGUARD_PUBLIC_KEY= \
WIREGUARD_ALLOWED_IPS= \
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
WIREGUARD_ADDRESSES= \
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
WIREGUARD_MTU=1400 \

View File

@@ -50,5 +50,6 @@ var (
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
)

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"regexp"
"strings"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers"
@@ -35,6 +36,7 @@ type Wireguard struct {
// to create. It cannot be the empty string in the
// internal state.
Interface string `json:"interface"`
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
// Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be zero in the internal state, and defaults to
// 1400. Note it is not the wireguard-go MTU default of 1420
@@ -123,6 +125,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
}
}
if *w.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
*w.PersistentKeepaliveInterval)
}
// Validate interface
if !regexpInterfaceName.MatchString(w.Interface) {
return fmt.Errorf("%w: '%s' does not match regex '%s'",
@@ -143,6 +150,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
Addresses: gosettings.CopySlice(w.Addresses),
AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
PersistentKeepaliveInterval: gosettings.CopyPointer(w.PersistentKeepaliveInterval),
Interface: w.Interface,
MTU: w.MTU,
Implementation: w.Implementation,
@@ -154,6 +162,8 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
w.PersistentKeepaliveInterval = gosettings.OverrideWithPointer(w.PersistentKeepaliveInterval,
other.PersistentKeepaliveInterval)
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
@@ -172,6 +182,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
}
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
const defaultMTU = 1400
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
@@ -205,6 +216,10 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
allowedIPsNode.Appendf(allowedIP.String())
}
if *w.PersistentKeepaliveInterval > 0 {
node.Appendf("Persistent keepalive interval: %s", w.PersistentKeepaliveInterval)
}
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
interfaceNode.Appendf("MTU: %d", w.MTU)
@@ -241,6 +256,12 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
if err != nil {
return err // already wrapped
}
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
if err != nil {
return err
}
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
if err != nil {
return err

View File

@@ -40,5 +40,7 @@ func BuildWireguardSettings(connection models.Connection,
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
}
settings.PersistentKeepaliveInterval = *userSettings.PersistentKeepaliveInterval
return settings
}

View File

@@ -3,6 +3,7 @@ package utils
import (
"net/netip"
"testing"
"time"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
@@ -10,7 +11,7 @@ import (
"github.com/stretchr/testify/assert"
)
func stringPtr(s string) *string { return &s }
func ptrTo[T any](x T) *T { return &x }
func Test_BuildWireguardSettings(t *testing.T) {
t.Parallel()
@@ -28,8 +29,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
PubKey: "public",
},
userSettings: settings.Wireguard{
PrivateKey: stringPtr("private"),
PreSharedKey: stringPtr("pre-shared"),
PrivateKey: ptrTo("private"),
PreSharedKey: ptrTo("pre-shared"),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
@@ -38,6 +39,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
},
PersistentKeepaliveInterval: ptrTo(time.Hour),
Interface: "wg1",
},
ipv6Supported: false,
@@ -53,6 +55,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
AllowedIPs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
},
PersistentKeepaliveInterval: time.Hour,
RulePriority: 101,
IPv6: boolPtr(false),
},

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -43,6 +44,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
preSharedKey = &preSharedKeyValue
}
var persistentKeepaliveInterval *time.Duration
if settings.PersistentKeepaliveInterval > 0 {
persistentKeepaliveInterval = new(time.Duration)
*persistentKeepaliveInterval = settings.PersistentKeepaliveInterval
}
firewallMark := settings.FirewallMark
config = wgtypes.Config{
@@ -63,6 +70,7 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
Mask: []byte(net.IPv6zero),
},
},
PersistentKeepaliveInterval: persistentKeepaliveInterval,
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
IP: settings.Endpoint.Addr().AsSlice(),

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"regexp"
"strings"
"time"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -30,6 +31,8 @@ type Settings struct {
// the Wireguard interface.
// Note IPv6 addresses are ignored if IPv6 is not supported.
AllowedIPs []netip.Prefix
// PersistentKeepaliveInterval defines the keep alive interval, if not zero.
PersistentKeepaliveInterval time.Duration
// FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0.
FirewallMark int
@@ -99,6 +102,7 @@ var (
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
ErrKeepaliveIsNegative = errors.New("keep alive interval is negative")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
@@ -160,6 +164,11 @@ func (s *Settings) Check() (err error) {
}
}
if s.PersistentKeepaliveInterval < 0 {
return fmt.Errorf("%w: %s", ErrKeepaliveIsNegative,
s.PersistentKeepaliveInterval)
}
if s.FirewallMark == 0 {
return fmt.Errorf("%w", ErrFirewallMarkMissing)
}
@@ -286,5 +295,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
}
}
if s.PersistentKeepaliveInterval > 0 {
lines = append(lines, fieldPrefix+"Persistent keep alive interval: "+
s.PersistentKeepaliveInterval.String())
}
return lines
}