feat(wireguard): WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL option
This commit is contained in:
@@ -100,6 +100,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
|
WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \
|
||||||
WIREGUARD_PUBLIC_KEY= \
|
WIREGUARD_PUBLIC_KEY= \
|
||||||
WIREGUARD_ALLOWED_IPS= \
|
WIREGUARD_ALLOWED_IPS= \
|
||||||
|
WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \
|
||||||
WIREGUARD_ADDRESSES= \
|
WIREGUARD_ADDRESSES= \
|
||||||
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \
|
||||||
WIREGUARD_MTU=1400 \
|
WIREGUARD_MTU=1400 \
|
||||||
|
|||||||
@@ -50,5 +50,6 @@ var (
|
|||||||
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
||||||
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
||||||
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
||||||
|
ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative")
|
||||||
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
|
ErrWireguardImplementationNotValid = errors.New("implementation is not valid")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||||
@@ -34,7 +35,8 @@ type Wireguard struct {
|
|||||||
// Interface is the name of the Wireguard interface
|
// Interface is the name of the Wireguard interface
|
||||||
// to create. It cannot be the empty string in the
|
// to create. It cannot be the empty string in the
|
||||||
// internal state.
|
// internal state.
|
||||||
Interface string `json:"interface"`
|
Interface string `json:"interface"`
|
||||||
|
PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"`
|
||||||
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
// Maximum Transmission Unit (MTU) of the Wireguard interface.
|
||||||
// It cannot be zero in the internal state, and defaults to
|
// It cannot be zero in the internal state, and defaults to
|
||||||
// 1400. Note it is not the wireguard-go MTU default of 1420
|
// 1400. Note it is not the wireguard-go MTU default of 1420
|
||||||
@@ -123,6 +125,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *w.PersistentKeepaliveInterval < 0 {
|
||||||
|
return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative,
|
||||||
|
*w.PersistentKeepaliveInterval)
|
||||||
|
}
|
||||||
|
|
||||||
// Validate interface
|
// Validate interface
|
||||||
if !regexpInterfaceName.MatchString(w.Interface) {
|
if !regexpInterfaceName.MatchString(w.Interface) {
|
||||||
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
return fmt.Errorf("%w: '%s' does not match regex '%s'",
|
||||||
@@ -139,13 +146,14 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
|||||||
|
|
||||||
func (w *Wireguard) copy() (copied Wireguard) {
|
func (w *Wireguard) copy() (copied Wireguard) {
|
||||||
return Wireguard{
|
return Wireguard{
|
||||||
PrivateKey: gosettings.CopyPointer(w.PrivateKey),
|
PrivateKey: gosettings.CopyPointer(w.PrivateKey),
|
||||||
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
|
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
|
||||||
Addresses: gosettings.CopySlice(w.Addresses),
|
Addresses: gosettings.CopySlice(w.Addresses),
|
||||||
AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
|
AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
|
||||||
Interface: w.Interface,
|
PersistentKeepaliveInterval: gosettings.CopyPointer(w.PersistentKeepaliveInterval),
|
||||||
MTU: w.MTU,
|
Interface: w.Interface,
|
||||||
Implementation: w.Implementation,
|
MTU: w.MTU,
|
||||||
|
Implementation: w.Implementation,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,6 +162,8 @@ func (w *Wireguard) overrideWith(other Wireguard) {
|
|||||||
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
|
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
|
||||||
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
|
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
|
||||||
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
|
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
|
||||||
|
w.PersistentKeepaliveInterval = gosettings.OverrideWithPointer(w.PersistentKeepaliveInterval,
|
||||||
|
other.PersistentKeepaliveInterval)
|
||||||
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
|
w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface)
|
||||||
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
|
w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU)
|
||||||
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
|
w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation)
|
||||||
@@ -172,6 +182,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
|||||||
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
}
|
}
|
||||||
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||||
|
w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0)
|
||||||
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
w.Interface = gosettings.DefaultComparable(w.Interface, "wg0")
|
||||||
const defaultMTU = 1400
|
const defaultMTU = 1400
|
||||||
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU)
|
||||||
@@ -205,6 +216,10 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
|||||||
allowedIPsNode.Appendf(allowedIP.String())
|
allowedIPsNode.Appendf(allowedIP.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *w.PersistentKeepaliveInterval > 0 {
|
||||||
|
node.Appendf("Persistent keepalive interval: %s", w.PersistentKeepaliveInterval)
|
||||||
|
}
|
||||||
|
|
||||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
interfaceNode.Appendf("MTU: %d", w.MTU)
|
||||||
|
|
||||||
@@ -241,6 +256,12 @@ func (w *Wireguard) read(r *reader.Reader) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err // already wrapped
|
return err // already wrapped
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
|
mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -40,5 +40,7 @@ func BuildWireguardSettings(connection models.Connection,
|
|||||||
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
|
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settings.PersistentKeepaliveInterval = *userSettings.PersistentKeepaliveInterval
|
||||||
|
|
||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
@@ -10,7 +11,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func stringPtr(s string) *string { return &s }
|
func ptrTo[T any](x T) *T { return &x }
|
||||||
|
|
||||||
func Test_BuildWireguardSettings(t *testing.T) {
|
func Test_BuildWireguardSettings(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
@@ -28,8 +29,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
PubKey: "public",
|
PubKey: "public",
|
||||||
},
|
},
|
||||||
userSettings: settings.Wireguard{
|
userSettings: settings.Wireguard{
|
||||||
PrivateKey: stringPtr("private"),
|
PrivateKey: ptrTo("private"),
|
||||||
PreSharedKey: stringPtr("pre-shared"),
|
PreSharedKey: ptrTo("pre-shared"),
|
||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
||||||
@@ -38,7 +39,8 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
||||||
},
|
},
|
||||||
Interface: "wg1",
|
PersistentKeepaliveInterval: ptrTo(time.Hour),
|
||||||
|
Interface: "wg1",
|
||||||
},
|
},
|
||||||
ipv6Supported: false,
|
ipv6Supported: false,
|
||||||
settings: wireguard.Settings{
|
settings: wireguard.Settings{
|
||||||
@@ -53,8 +55,9 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
AllowedIPs: []netip.Prefix{
|
AllowedIPs: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
},
|
},
|
||||||
RulePriority: 101,
|
PersistentKeepaliveInterval: time.Hour,
|
||||||
IPv6: boolPtr(false),
|
RulePriority: 101,
|
||||||
|
IPv6: boolPtr(false),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -43,6 +44,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
|||||||
preSharedKey = &preSharedKeyValue
|
preSharedKey = &preSharedKeyValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var persistentKeepaliveInterval *time.Duration
|
||||||
|
if settings.PersistentKeepaliveInterval > 0 {
|
||||||
|
persistentKeepaliveInterval = new(time.Duration)
|
||||||
|
*persistentKeepaliveInterval = settings.PersistentKeepaliveInterval
|
||||||
|
}
|
||||||
|
|
||||||
firewallMark := settings.FirewallMark
|
firewallMark := settings.FirewallMark
|
||||||
|
|
||||||
config = wgtypes.Config{
|
config = wgtypes.Config{
|
||||||
@@ -63,7 +70,8 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
|||||||
Mask: []byte(net.IPv6zero),
|
Mask: []byte(net.IPv6zero),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ReplaceAllowedIPs: true,
|
PersistentKeepaliveInterval: persistentKeepaliveInterval,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
Endpoint: &net.UDPAddr{
|
Endpoint: &net.UDPAddr{
|
||||||
IP: settings.Endpoint.Addr().AsSlice(),
|
IP: settings.Endpoint.Addr().AsSlice(),
|
||||||
Port: int(settings.Endpoint.Port()),
|
Port: int(settings.Endpoint.Port()),
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -30,6 +31,8 @@ type Settings struct {
|
|||||||
// the Wireguard interface.
|
// the Wireguard interface.
|
||||||
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
||||||
AllowedIPs []netip.Prefix
|
AllowedIPs []netip.Prefix
|
||||||
|
// PersistentKeepaliveInterval defines the keep alive interval, if not zero.
|
||||||
|
PersistentKeepaliveInterval time.Duration
|
||||||
// FirewallMark to be used in routing tables and IP rules.
|
// FirewallMark to be used in routing tables and IP rules.
|
||||||
// It defaults to 51820 if left to 0.
|
// It defaults to 51820 if left to 0.
|
||||||
FirewallMark int
|
FirewallMark int
|
||||||
@@ -99,6 +102,7 @@ var (
|
|||||||
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
|
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
|
||||||
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
|
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
|
||||||
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
|
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
|
||||||
|
ErrKeepaliveIsNegative = errors.New("keep alive interval is negative")
|
||||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||||
ErrMTUMissing = errors.New("MTU is missing")
|
ErrMTUMissing = errors.New("MTU is missing")
|
||||||
ErrImplementationInvalid = errors.New("invalid implementation")
|
ErrImplementationInvalid = errors.New("invalid implementation")
|
||||||
@@ -160,6 +164,11 @@ func (s *Settings) Check() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.PersistentKeepaliveInterval < 0 {
|
||||||
|
return fmt.Errorf("%w: %s", ErrKeepaliveIsNegative,
|
||||||
|
s.PersistentKeepaliveInterval)
|
||||||
|
}
|
||||||
|
|
||||||
if s.FirewallMark == 0 {
|
if s.FirewallMark == 0 {
|
||||||
return fmt.Errorf("%w", ErrFirewallMarkMissing)
|
return fmt.Errorf("%w", ErrFirewallMarkMissing)
|
||||||
}
|
}
|
||||||
@@ -286,5 +295,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.PersistentKeepaliveInterval > 0 {
|
||||||
|
lines = append(lines, fieldPrefix+"Persistent keep alive interval: "+
|
||||||
|
s.PersistentKeepaliveInterval.String())
|
||||||
|
}
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user