feat(wireguard): WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL option
This commit is contained in:
@@ -100,6 +100,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
|
||||
WIREGUARD_PUBLIC_KEY= \
|
||||
WIREGUARD_ALLOWED_IPS= \
|
||||
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||
WIREGUARD_ADDRESSES= \
|
||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||
WIREGUARD_MTU=1400 \
|
||||
|
||||
@@ -50,5 +50,6 @@ var (
|
||||
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
||||
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
||||
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
||||
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
|
||||
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
@@ -35,6 +36,7 @@ type Wireguard struct {
|
||||
// to create. It cannot be the empty string in the
|
||||
// internal state.
|
||||
Interface string `json:"interface"`
|
||||
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
||||
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
||||
// It cannot be zero in the internal state, and defaults to
|
||||
// 1400. Note it is not the wireguard-go MTU default of 1420
|
||||
@@ -123,6 +125,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
||||
}
|
||||
}
|
||||
|
||||
if *w.PersistentKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
|
||||
*w.PersistentKeepaliveInterval)
|
||||
}
|
||||
|
||||
// Validate interface
|
||||
if !regexpInterfaceName.MatchString(w.Interface) {
|
||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
||||
@@ -143,6 +150,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
|
||||
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
|
||||
Addresses: gosettings.CopySlice(w.Addresses),
|
||||
AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
|
||||
PersistentKeepaliveInterval: gosettings.CopyPointer(w.PersistentKeepaliveInterval),
|
||||
Interface: w.Interface,
|
||||
MTU: w.MTU,
|
||||
Implementation: w.Implementation,
|
||||
@@ -154,6 +162,8 @@ func (w *Wireguard) overrideWith(other Wireguard) {
|
||||
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
|
||||
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
|
||||
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
|
||||
w.PersistentKeepaliveInterval = gosettings.OverrideWithPointer(w.PersistentKeepaliveInterval,
|
||||
other.PersistentKeepaliveInterval)
|
||||
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
|
||||
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
|
||||
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
|
||||
@@ -172,6 +182,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
||||
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||
}
|
||||
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||
const defaultMTU = 1400
|
||||
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
||||
@@ -205,6 +216,10 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
||||
allowedIPsNode.Appendf(allowedIP.String())
|
||||
}
|
||||
|
||||
if *w.PersistentKeepaliveInterval > 0 {
|
||||
node.Appendf("Persistent keepalive interval: %s", w.PersistentKeepaliveInterval)
|
||||
}
|
||||
|
||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
||||
|
||||
@@ -241,6 +256,12 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
||||
if err != nil {
|
||||
return err // already wrapped
|
||||
}
|
||||
|
||||
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -40,5 +40,7 @@ func BuildWireguardSettings(connection models.Connection,
|
||||
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
|
||||
}
|
||||
|
||||
settings.PersistentKeepaliveInterval = *userSettings.PersistentKeepaliveInterval
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -10,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func stringPtr(s string) *string { return &s }
|
||||
func ptrTo[T any](x T) *T { return &x }
|
||||
|
||||
func Test_BuildWireguardSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -28,8 +29,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
||||
PubKey: "public",
|
||||
},
|
||||
userSettings: settings.Wireguard{
|
||||
PrivateKey: stringPtr("private"),
|
||||
PreSharedKey: stringPtr("pre-shared"),
|
||||
PrivateKey: ptrTo("private"),
|
||||
PreSharedKey: ptrTo("pre-shared"),
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
||||
@@ -38,6 +39,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
||||
},
|
||||
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
||||
Interface: "wg1",
|
||||
},
|
||||
ipv6Supported: false,
|
||||
@@ -53,6 +55,7 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
||||
AllowedIPs: []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||
},
|
||||
PersistentKeepaliveInterval: time.Hour,
|
||||
RulePriority: 101,
|
||||
IPv6: boolPtr(false),
|
||||
},
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -43,6 +44,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
||||
preSharedKey = &preSharedKeyValue
|
||||
}
|
||||
|
||||
var persistentKeepaliveInterval *time.Duration
|
||||
if settings.PersistentKeepaliveInterval > 0 {
|
||||
persistentKeepaliveInterval = new(time.Duration)
|
||||
*persistentKeepaliveInterval = settings.PersistentKeepaliveInterval
|
||||
}
|
||||
|
||||
firewallMark := settings.FirewallMark
|
||||
|
||||
config = wgtypes.Config{
|
||||
@@ -63,6 +70,7 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
||||
Mask: []byte(net.IPv6zero),
|
||||
},
|
||||
},
|
||||
PersistentKeepaliveInterval: persistentKeepaliveInterval,
|
||||
ReplaceAllowedIPs: true,
|
||||
Endpoint: &net.UDPAddr{
|
||||
IP: settings.Endpoint.Addr().AsSlice(),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -30,6 +31,8 @@ type Settings struct {
|
||||
// the Wireguard interface.
|
||||
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
||||
AllowedIPs []netip.Prefix
|
||||
// PersistentKeepaliveInterval defines the keep alive interval, if not zero.
|
||||
PersistentKeepaliveInterval time.Duration
|
||||
// FirewallMark to be used in routing tables and IP rules.
|
||||
// It defaults to 51820 if left to 0.
|
||||
FirewallMark int
|
||||
@@ -99,6 +102,7 @@ var (
|
||||
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
|
||||
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
|
||||
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
|
||||
ErrKeepaliveIsNegative = errors.New("keep alive interval is negative")
|
||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||
ErrMTUMissing = errors.New("MTU is missing")
|
||||
ErrImplementationInvalid = errors.New("invalid implementation")
|
||||
@@ -160,6 +164,11 @@ func (s *Settings) Check() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if s.PersistentKeepaliveInterval < 0 {
|
||||
return fmt.Errorf("%w: %s", ErrKeepaliveIsNegative,
|
||||
s.PersistentKeepaliveInterval)
|
||||
}
|
||||
|
||||
if s.FirewallMark == 0 {
|
||||
return fmt.Errorf("%w", ErrFirewallMarkMissing)
|
||||
}
|
||||
@@ -286,5 +295,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
|
||||
}
|
||||
}
|
||||
|
||||
if s.PersistentKeepaliveInterval > 0 {
|
||||
lines = append(lines, fieldPrefix+"Persistent keep alive interval: "+
|
||||
s.PersistentKeepaliveInterval.String())
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user