feat(wireguard): WIREGUARD_ALLOWED_IPS variable (#1291)
This commit is contained in:
@@ -96,6 +96,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
WIREGUARD_PRIVATE_KEY= \
|
WIREGUARD_PRIVATE_KEY= \
|
||||||
WIREGUARD_PRESHARED_KEY= \
|
WIREGUARD_PRESHARED_KEY= \
|
||||||
WIREGUARD_PUBLIC_KEY= \
|
WIREGUARD_PUBLIC_KEY= \
|
||||||
|
WIREGUARD_ALLOWED_IPS= \
|
||||||
WIREGUARD_ADDRESSES= \
|
WIREGUARD_ADDRESSES= \
|
||||||
WIREGUARD_MTU=1400 \
|
WIREGUARD_MTU=1400 \
|
||||||
WIREGUARD_IMPLEMENTATION=auto \
|
WIREGUARD_IMPLEMENTATION=auto \
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ var (
|
|||||||
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
||||||
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
||||||
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
||||||
|
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
|
||||||
|
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
|
||||||
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
|
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
|
||||||
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
|
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
|
||||||
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
|
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ type Wireguard struct {
|
|||||||
PreSharedKey *string `json:"pre_shared_key"`
|
PreSharedKey *string `json:"pre_shared_key"`
|
||||||
// Addresses are the Wireguard interface addresses.
|
// Addresses are the Wireguard interface addresses.
|
||||||
Addresses []netip.Prefix `json:"addresses"`
|
Addresses []netip.Prefix `json:"addresses"`
|
||||||
|
// AllowedIPs are the Wireguard allowed IPs.
|
||||||
|
// If left unset, they default to "0.0.0.0/0"
|
||||||
|
// and, if IPv6 is supported, "::0".
|
||||||
|
AllowedIPs []netip.Prefix `json:"allowed_ips"`
|
||||||
// Interface is the name of the Wireguard interface
|
// Interface is the name of the Wireguard interface
|
||||||
// to create. It cannot be the empty string in the
|
// to create. It cannot be the empty string in the
|
||||||
// internal state.
|
// internal state.
|
||||||
@@ -89,13 +93,26 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
|
|||||||
}
|
}
|
||||||
for i, ipNet := range w.Addresses {
|
for i, ipNet := range w.Addresses {
|
||||||
if !ipNet.IsValid() {
|
if !ipNet.IsValid() {
|
||||||
return fmt.Errorf("%w: for address at index %d: %s",
|
return fmt.Errorf("%w: for address at index %d",
|
||||||
ErrWireguardInterfaceAddressNotSet, i, ipNet.String())
|
ErrWireguardInterfaceAddressNotSet, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ipv6Supported && ipNet.Addr().Is6() {
|
if !ipv6Supported && ipNet.Addr().Is6() {
|
||||||
return fmt.Errorf("%w: address %s",
|
return fmt.Errorf("%w: address %s",
|
||||||
ErrWireguardInterfaceAddressIPv6, ipNet)
|
ErrWireguardInterfaceAddressIPv6, ipNet.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate AllowedIPs
|
||||||
|
// WARNING: do not check for IPv6 networks in the allowed IPs,
|
||||||
|
// the wireguard code will take care to ignore it.
|
||||||
|
if len(w.AllowedIPs) == 0 {
|
||||||
|
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
|
||||||
|
}
|
||||||
|
for i, allowedIP := range w.AllowedIPs {
|
||||||
|
if !allowedIP.IsValid() {
|
||||||
|
return fmt.Errorf("%w: for allowed ip %d of %d",
|
||||||
|
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,6 +135,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
|
|||||||
PrivateKey: gosettings.CopyPointer(w.PrivateKey),
|
PrivateKey: gosettings.CopyPointer(w.PrivateKey),
|
||||||
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
|
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
|
||||||
Addresses: gosettings.CopySlice(w.Addresses),
|
Addresses: gosettings.CopySlice(w.Addresses),
|
||||||
|
AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
|
||||||
Interface: w.Interface,
|
Interface: w.Interface,
|
||||||
MTU: w.MTU,
|
MTU: w.MTU,
|
||||||
Implementation: w.Implementation,
|
Implementation: w.Implementation,
|
||||||
@@ -128,6 +146,7 @@ func (w *Wireguard) mergeWith(other Wireguard) {
|
|||||||
w.PrivateKey = gosettings.MergeWithPointer(w.PrivateKey, other.PrivateKey)
|
w.PrivateKey = gosettings.MergeWithPointer(w.PrivateKey, other.PrivateKey)
|
||||||
w.PreSharedKey = gosettings.MergeWithPointer(w.PreSharedKey, other.PreSharedKey)
|
w.PreSharedKey = gosettings.MergeWithPointer(w.PreSharedKey, other.PreSharedKey)
|
||||||
w.Addresses = gosettings.MergeWithSlice(w.Addresses, other.Addresses)
|
w.Addresses = gosettings.MergeWithSlice(w.Addresses, other.Addresses)
|
||||||
|
w.AllowedIPs = gosettings.MergeWithSlice(w.AllowedIPs, other.AllowedIPs)
|
||||||
w.Interface = gosettings.MergeWithString(w.Interface, other.Interface)
|
w.Interface = gosettings.MergeWithString(w.Interface, other.Interface)
|
||||||
w.MTU = gosettings.MergeWithNumber(w.MTU, other.MTU)
|
w.MTU = gosettings.MergeWithNumber(w.MTU, other.MTU)
|
||||||
w.Implementation = gosettings.MergeWithString(w.Implementation, other.Implementation)
|
w.Implementation = gosettings.MergeWithString(w.Implementation, other.Implementation)
|
||||||
@@ -137,6 +156,7 @@ func (w *Wireguard) overrideWith(other Wireguard) {
|
|||||||
w.PrivateKey = gosettings.OverrideWithPointer(w.PrivateKey, other.PrivateKey)
|
w.PrivateKey = gosettings.OverrideWithPointer(w.PrivateKey, other.PrivateKey)
|
||||||
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
|
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
|
||||||
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
|
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
|
||||||
|
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
|
||||||
w.Interface = gosettings.OverrideWithString(w.Interface, other.Interface)
|
w.Interface = gosettings.OverrideWithString(w.Interface, other.Interface)
|
||||||
w.MTU = gosettings.OverrideWithNumber(w.MTU, other.MTU)
|
w.MTU = gosettings.OverrideWithNumber(w.MTU, other.MTU)
|
||||||
w.Implementation = gosettings.OverrideWithString(w.Implementation, other.Implementation)
|
w.Implementation = gosettings.OverrideWithString(w.Implementation, other.Implementation)
|
||||||
@@ -150,6 +170,11 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
|
|||||||
defaultNordVPNPrefix := netip.PrefixFrom(defaultNordVPNAddress, defaultNordVPNAddress.BitLen())
|
defaultNordVPNPrefix := netip.PrefixFrom(defaultNordVPNAddress, defaultNordVPNAddress.BitLen())
|
||||||
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultNordVPNPrefix})
|
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultNordVPNPrefix})
|
||||||
}
|
}
|
||||||
|
defaultAllowedIPs := []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
}
|
||||||
|
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
|
||||||
w.Interface = gosettings.DefaultString(w.Interface, "wg0")
|
w.Interface = gosettings.DefaultString(w.Interface, "wg0")
|
||||||
const defaultMTU = 1400
|
const defaultMTU = 1400
|
||||||
w.MTU = gosettings.DefaultNumber(w.MTU, defaultMTU)
|
w.MTU = gosettings.DefaultNumber(w.MTU, defaultMTU)
|
||||||
@@ -178,6 +203,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
|
|||||||
addressesNode.Appendf(address.String())
|
addressesNode.Appendf(address.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowedIPsNode := node.Appendf("Allowed IPs:")
|
||||||
|
for _, allowedIP := range w.AllowedIPs {
|
||||||
|
allowedIPsNode.Appendf(allowedIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
interfaceNode := node.Appendf("Network interface: %s", w.Interface)
|
||||||
interfaceNode.Appendf("MTU: %d", w.MTU)
|
interfaceNode.Appendf("MTU: %d", w.MTU)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return wireguard, err // already wrapped
|
return wireguard, err // already wrapped
|
||||||
}
|
}
|
||||||
|
wireguard.AllowedIPs, err = s.env.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS")
|
||||||
|
if err != nil {
|
||||||
|
return wireguard, err // already wrapped
|
||||||
|
}
|
||||||
mtuPtr, err := s.env.Uint16Ptr("WIREGUARD_MTU")
|
mtuPtr, err := s.env.Uint16Ptr("WIREGUARD_MTU")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wireguard, err
|
return wireguard, err
|
||||||
|
|||||||
@@ -32,5 +32,13 @@ func BuildWireguardSettings(connection models.Connection,
|
|||||||
settings.Addresses = append(settings.Addresses, addressCopy)
|
settings.Addresses = append(settings.Addresses, addressCopy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settings.AllowedIPs = make([]netip.Prefix, 0, len(userSettings.AllowedIPs))
|
||||||
|
for _, allowedIP := range userSettings.AllowedIPs {
|
||||||
|
if !ipv6Supported && allowedIP.Addr().Is6() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
|
||||||
|
}
|
||||||
|
|
||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
||||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
||||||
},
|
},
|
||||||
|
AllowedIPs: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
|
||||||
|
},
|
||||||
Interface: "wg1",
|
Interface: "wg1",
|
||||||
},
|
},
|
||||||
ipv6Supported: false,
|
ipv6Supported: false,
|
||||||
@@ -46,6 +50,9 @@ func Test_BuildWireguardSettings(t *testing.T) {
|
|||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
||||||
},
|
},
|
||||||
|
AllowedIPs: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
},
|
||||||
RulePriority: 101,
|
RulePriority: 101,
|
||||||
IPv6: boolPtr(false),
|
IPv6: boolPtr(false),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ func Test_New(t *testing.T) {
|
|||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
|
||||||
},
|
},
|
||||||
|
AllowedIPs: []netip.Prefix{
|
||||||
|
allIPv4(),
|
||||||
|
},
|
||||||
FirewallMark: 100,
|
FirewallMark: 100,
|
||||||
MTU: device.DefaultMTU,
|
MTU: device.DefaultMTU,
|
||||||
IPv6: ptr(false),
|
IPv6: ptr(false),
|
||||||
|
|||||||
@@ -3,11 +3,30 @@ package wireguard
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO add IPv6 route if IPv6 is supported
|
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
|
||||||
|
firewallMark int) (err error) {
|
||||||
|
for _, dst := range destinations {
|
||||||
|
err = w.addRoute(link, dst, firewallMark)
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
|
||||||
|
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
||||||
|
"Ignoring and continuing execution; "+
|
||||||
|
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
|
||||||
|
"Full error string: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("adding route for destination %s: %w", dst, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
||||||
firewallMark int) (err error) {
|
firewallMark int) (err error) {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@@ -103,7 +102,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
return w.netlink.LinkSetDown(link)
|
return w.netlink.LinkSetDown(link)
|
||||||
})
|
})
|
||||||
|
|
||||||
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
|
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
||||||
return
|
return
|
||||||
@@ -111,11 +110,13 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
|
|
||||||
if *w.settings.IPv6 {
|
if *w.settings.IPv6 {
|
||||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
// requires net.ipv6.conf.all.disable_ipv6=0
|
||||||
err = w.setupIPv6(link, &closers)
|
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
|
||||||
|
w.settings.FirewallMark, unix.AF_INET6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("setting up IPv6: %w", err)
|
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
ruleCleanup, err := w.addRule(w.settings.RulePriority,
|
||||||
@@ -132,31 +133,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
waitError <- waitAndCleanup()
|
waitError <- waitAndCleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) setupIPv6(link netlink.Link, closers *closers) (err error) {
|
|
||||||
// requires net.ipv6.conf.all.disable_ipv6=0
|
|
||||||
err = w.addRoute(link, allIPv6(), w.settings.FirewallMark)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "permission denied") {
|
|
||||||
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
|
|
||||||
"Ignoring and continuing execution; "+
|
|
||||||
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
|
|
||||||
"Full error string: %s", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("%w: %s", ErrRouteAdd, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleCleanup6, ruleErr := w.addRule(
|
|
||||||
w.settings.RulePriority, w.settings.FirewallMark,
|
|
||||||
unix.AF_INET6)
|
|
||||||
if ruleErr != nil {
|
|
||||||
return fmt.Errorf("adding IPv6 rule: %w", ruleErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type waitAndCleanupFunc func() error
|
type waitAndCleanupFunc func() error
|
||||||
|
|
||||||
func setupKernelSpace(ctx context.Context,
|
func setupKernelSpace(ctx context.Context,
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ type Settings struct {
|
|||||||
// Addresses assigned to the client.
|
// Addresses assigned to the client.
|
||||||
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
||||||
Addresses []netip.Prefix
|
Addresses []netip.Prefix
|
||||||
|
// AllowedIPs is the IP networks to be routed through
|
||||||
|
// the Wireguard interface.
|
||||||
|
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
||||||
|
AllowedIPs []netip.Prefix
|
||||||
// FirewallMark to be used in routing tables and IP rules.
|
// FirewallMark to be used in routing tables and IP rules.
|
||||||
// It defaults to 51820 if left to 0.
|
// It defaults to 51820 if left to 0.
|
||||||
FirewallMark int
|
FirewallMark int
|
||||||
@@ -68,6 +72,13 @@ func (s *Settings) SetDefaults() {
|
|||||||
s.IPv6 = &ipv6
|
s.IPv6 = &ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(s.AllowedIPs) == 0 {
|
||||||
|
s.AllowedIPs = append(s.AllowedIPs, allIPv4())
|
||||||
|
if *s.IPv6 {
|
||||||
|
s.AllowedIPs = append(s.AllowedIPs, allIPv6())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.Implementation == "" {
|
if s.Implementation == "" {
|
||||||
const defaultImplementation = "auto"
|
const defaultImplementation = "auto"
|
||||||
s.Implementation = defaultImplementation
|
s.Implementation = defaultImplementation
|
||||||
@@ -85,6 +96,9 @@ var (
|
|||||||
ErrEndpointPortMissing = errors.New("endpoint port is missing")
|
ErrEndpointPortMissing = errors.New("endpoint port is missing")
|
||||||
ErrAddressMissing = errors.New("interface address is missing")
|
ErrAddressMissing = errors.New("interface address is missing")
|
||||||
ErrAddressNotValid = errors.New("interface address is not valid")
|
ErrAddressNotValid = errors.New("interface address is not valid")
|
||||||
|
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
|
||||||
|
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
|
||||||
|
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
|
||||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||||
ErrMTUMissing = errors.New("MTU is missing")
|
ErrMTUMissing = errors.New("MTU is missing")
|
||||||
ErrImplementationInvalid = errors.New("invalid implementation")
|
ErrImplementationInvalid = errors.New("invalid implementation")
|
||||||
@@ -132,6 +146,20 @@ func (s *Settings) Check() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(s.AllowedIPs) == 0 {
|
||||||
|
return fmt.Errorf("%w", ErrAllowedIPsMissing)
|
||||||
|
}
|
||||||
|
for i, allowedIP := range s.AllowedIPs {
|
||||||
|
switch {
|
||||||
|
case !allowedIP.IsValid():
|
||||||
|
return fmt.Errorf("%w: for allowed IP %d of %d",
|
||||||
|
ErrAllowedIPNotValid, i+1, len(s.AllowedIPs))
|
||||||
|
case allowedIP.Addr().Is6() && !*s.IPv6:
|
||||||
|
return fmt.Errorf("%w: for allowed IP %s",
|
||||||
|
ErrAllowedIPv6NotSupported, allowedIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.FirewallMark == 0 {
|
if s.FirewallMark == 0 {
|
||||||
return fmt.Errorf("%w", ErrFirewallMarkMissing)
|
return fmt.Errorf("%w", ErrFirewallMarkMissing)
|
||||||
}
|
}
|
||||||
@@ -247,5 +275,16 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(s.AllowedIPs) > 0 {
|
||||||
|
lines = append(lines, fieldPrefix+"Allowed IPs:")
|
||||||
|
for i, allowedIP := range s.AllowedIPs {
|
||||||
|
prefix := fieldPrefix
|
||||||
|
if i == len(s.AllowedIPs)-1 {
|
||||||
|
prefix = lastFieldPrefix
|
||||||
|
}
|
||||||
|
lines = append(lines, indent+prefix+allowedIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,6 +21,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
|
|||||||
expected: Settings{
|
expected: Settings{
|
||||||
InterfaceName: "wg0",
|
InterfaceName: "wg0",
|
||||||
FirewallMark: 51820,
|
FirewallMark: 51820,
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
MTU: device.DefaultMTU,
|
MTU: device.DefaultMTU,
|
||||||
IPv6: ptr(false),
|
IPv6: ptr(false),
|
||||||
Implementation: "auto",
|
Implementation: "auto",
|
||||||
@@ -36,6 +35,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
|
|||||||
InterfaceName: "wg0",
|
InterfaceName: "wg0",
|
||||||
FirewallMark: 51820,
|
FirewallMark: 51820,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
MTU: device.DefaultMTU,
|
MTU: device.DefaultMTU,
|
||||||
IPv6: ptr(false),
|
IPv6: ptr(false),
|
||||||
Implementation: "auto",
|
Implementation: "auto",
|
||||||
@@ -46,6 +46,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
|
|||||||
InterfaceName: "wg1",
|
InterfaceName: "wg1",
|
||||||
FirewallMark: 999,
|
FirewallMark: 999,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
MTU: device.DefaultMTU,
|
MTU: device.DefaultMTU,
|
||||||
IPv6: ptr(true),
|
IPv6: ptr(true),
|
||||||
Implementation: "userspace",
|
Implementation: "userspace",
|
||||||
@@ -54,6 +55,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
|
|||||||
InterfaceName: "wg1",
|
InterfaceName: "wg1",
|
||||||
FirewallMark: 999,
|
FirewallMark: 999,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
MTU: device.DefaultMTU,
|
MTU: device.DefaultMTU,
|
||||||
IPv6: ptr(true),
|
IPv6: ptr(true),
|
||||||
Implementation: "userspace",
|
Implementation: "userspace",
|
||||||
@@ -83,36 +85,42 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
settings Settings
|
settings Settings
|
||||||
err error
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"empty settings": {
|
"empty settings": {
|
||||||
err: errors.New("invalid interface name: "),
|
errWrapped: ErrInterfaceNameInvalid,
|
||||||
|
errMessage: "invalid interface name: ",
|
||||||
},
|
},
|
||||||
"bad interface name": {
|
"bad interface name": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
InterfaceName: "$H1T",
|
InterfaceName: "$H1T",
|
||||||
},
|
},
|
||||||
err: errors.New("invalid interface name: $H1T"),
|
errWrapped: ErrInterfaceNameInvalid,
|
||||||
|
errMessage: "invalid interface name: $H1T",
|
||||||
},
|
},
|
||||||
"empty private key": {
|
"empty private key": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
InterfaceName: "wg0",
|
InterfaceName: "wg0",
|
||||||
},
|
},
|
||||||
err: ErrPrivateKeyMissing,
|
errWrapped: ErrPrivateKeyMissing,
|
||||||
|
errMessage: "private key is missing",
|
||||||
},
|
},
|
||||||
"bad private key": {
|
"bad private key": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
InterfaceName: "wg0",
|
InterfaceName: "wg0",
|
||||||
PrivateKey: "bad key",
|
PrivateKey: "bad key",
|
||||||
},
|
},
|
||||||
err: ErrPrivateKeyInvalid,
|
errWrapped: ErrPrivateKeyInvalid,
|
||||||
|
errMessage: "cannot parse private key",
|
||||||
},
|
},
|
||||||
"empty public key": {
|
"empty public key": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
InterfaceName: "wg0",
|
InterfaceName: "wg0",
|
||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
},
|
},
|
||||||
err: ErrPublicKeyMissing,
|
errWrapped: ErrPublicKeyMissing,
|
||||||
|
errMessage: "public key is missing",
|
||||||
},
|
},
|
||||||
"bad public key": {
|
"bad public key": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -120,7 +128,8 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
PublicKey: "bad key",
|
PublicKey: "bad key",
|
||||||
},
|
},
|
||||||
err: errors.New("cannot parse public key: bad key"),
|
errWrapped: ErrPublicKeyInvalid,
|
||||||
|
errMessage: "cannot parse public key: bad key",
|
||||||
},
|
},
|
||||||
"bad preshared key": {
|
"bad preshared key": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -129,7 +138,8 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
PreSharedKey: "bad key",
|
PreSharedKey: "bad key",
|
||||||
},
|
},
|
||||||
err: errors.New("cannot parse pre-shared key"),
|
errWrapped: ErrPreSharedKeyInvalid,
|
||||||
|
errMessage: "cannot parse pre-shared key",
|
||||||
},
|
},
|
||||||
"invalid endpoint address": {
|
"invalid endpoint address": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -137,7 +147,8 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
},
|
},
|
||||||
err: ErrEndpointAddrMissing,
|
errWrapped: ErrEndpointAddrMissing,
|
||||||
|
errMessage: "endpoint address is missing",
|
||||||
},
|
},
|
||||||
"zero endpoint port": {
|
"zero endpoint port": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -146,7 +157,8 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0),
|
||||||
},
|
},
|
||||||
err: ErrEndpointPortMissing,
|
errWrapped: ErrEndpointPortMissing,
|
||||||
|
errMessage: "endpoint port is missing",
|
||||||
},
|
},
|
||||||
"no address": {
|
"no address": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -155,7 +167,8 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
},
|
},
|
||||||
err: ErrAddressMissing,
|
errWrapped: ErrAddressMissing,
|
||||||
|
errMessage: "interface address is missing",
|
||||||
},
|
},
|
||||||
"invalid address": {
|
"invalid address": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -165,7 +178,53 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
Addresses: []netip.Prefix{{}},
|
Addresses: []netip.Prefix{{}},
|
||||||
},
|
},
|
||||||
err: errors.New("interface address is not valid: for address 1 of 1"),
|
errWrapped: ErrAddressNotValid,
|
||||||
|
errMessage: "interface address is not valid: for address 1 of 1",
|
||||||
|
},
|
||||||
|
|
||||||
|
"no allowed IP": {
|
||||||
|
settings: Settings{
|
||||||
|
InterfaceName: "wg0",
|
||||||
|
PrivateKey: validKey1,
|
||||||
|
PublicKey: validKey2,
|
||||||
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
Addresses: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errWrapped: ErrAllowedIPsMissing,
|
||||||
|
errMessage: "allowed IPs are missing",
|
||||||
|
},
|
||||||
|
"invalid allowed IP": {
|
||||||
|
settings: Settings{
|
||||||
|
InterfaceName: "wg0",
|
||||||
|
PrivateKey: validKey1,
|
||||||
|
PublicKey: validKey2,
|
||||||
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
Addresses: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24),
|
||||||
|
},
|
||||||
|
AllowedIPs: []netip.Prefix{{}},
|
||||||
|
},
|
||||||
|
errWrapped: ErrAllowedIPNotValid,
|
||||||
|
errMessage: "allowed IP is not valid: for allowed IP 1 of 1",
|
||||||
|
},
|
||||||
|
"ipv6 allowed IP": {
|
||||||
|
settings: Settings{
|
||||||
|
InterfaceName: "wg0",
|
||||||
|
PrivateKey: validKey1,
|
||||||
|
PublicKey: validKey2,
|
||||||
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
Addresses: []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24),
|
||||||
|
},
|
||||||
|
AllowedIPs: []netip.Prefix{
|
||||||
|
allIPv6(),
|
||||||
|
},
|
||||||
|
IPv6: ptrTo(false),
|
||||||
|
},
|
||||||
|
errWrapped: ErrAllowedIPv6NotSupported,
|
||||||
|
errMessage: "allowed IPv6 address not supported: for allowed IP ::/0",
|
||||||
},
|
},
|
||||||
"zero firewall mark": {
|
"zero firewall mark": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -173,11 +232,13 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
err: ErrFirewallMarkMissing,
|
errWrapped: ErrFirewallMarkMissing,
|
||||||
|
errMessage: "firewall mark is missing",
|
||||||
},
|
},
|
||||||
"missing_MTU": {
|
"missing_MTU": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -185,12 +246,14 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
},
|
},
|
||||||
FirewallMark: 999,
|
FirewallMark: 999,
|
||||||
},
|
},
|
||||||
err: ErrMTUMissing,
|
errWrapped: ErrMTUMissing,
|
||||||
|
errMessage: "MTU is missing",
|
||||||
},
|
},
|
||||||
"invalid implementation": {
|
"invalid implementation": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -198,6 +261,7 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
AllowedIPs: []netip.Prefix{allIPv4()},
|
||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
},
|
},
|
||||||
@@ -205,7 +269,8 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
MTU: 1420,
|
MTU: 1420,
|
||||||
Implementation: "x",
|
Implementation: "x",
|
||||||
},
|
},
|
||||||
err: errors.New("invalid implementation: x"),
|
errWrapped: ErrImplementationInvalid,
|
||||||
|
errMessage: "invalid implementation: x",
|
||||||
},
|
},
|
||||||
"all valid": {
|
"all valid": {
|
||||||
settings: Settings{
|
settings: Settings{
|
||||||
@@ -213,11 +278,15 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
PrivateKey: validKey1,
|
PrivateKey: validKey1,
|
||||||
PublicKey: validKey2,
|
PublicKey: validKey2,
|
||||||
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
|
||||||
|
AllowedIPs: []netip.Prefix{
|
||||||
|
allIPv6(),
|
||||||
|
},
|
||||||
Addresses: []netip.Prefix{
|
Addresses: []netip.Prefix{
|
||||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
},
|
},
|
||||||
FirewallMark: 999,
|
FirewallMark: 999,
|
||||||
MTU: 1420,
|
MTU: 1420,
|
||||||
|
IPv6: ptrTo(true),
|
||||||
Implementation: "userspace",
|
Implementation: "userspace",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -230,11 +299,9 @@ func Test_Settings_Check(t *testing.T) {
|
|||||||
|
|
||||||
err := testCase.settings.Check()
|
err := testCase.settings.Check()
|
||||||
|
|
||||||
if testCase.err != nil {
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
require.Error(t, err)
|
if testCase.errWrapped != nil {
|
||||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user