chore(wireguard): use netip.AddrPort instead of *net.UDPAddr

This commit is contained in:
Quentin McGaw
2023-05-20 20:05:40 +00:00
parent 0a29337c3b
commit 86ec75722a
7 changed files with 42 additions and 99 deletions

View File

@@ -3,7 +3,6 @@ package wireguard
import (
"errors"
"fmt"
"net"
"net/netip"
"regexp"
"strings"
@@ -22,7 +21,7 @@ type Settings struct {
// Pre shared key in base 64 format
PreSharedKey string
// Wireguard server endpoint to connect to.
Endpoint *net.UDPAddr
Endpoint netip.AddrPort
// Addresses assigned to the client.
// Note IPv6 addresses are ignored if IPv6 is not supported.
Addresses []netip.Prefix
@@ -46,9 +45,9 @@ func (s *Settings) SetDefaults() {
s.InterfaceName = defaultInterfaceName
}
if s.Endpoint != nil && s.Endpoint.Port == 0 {
if s.Endpoint.IsValid() && s.Endpoint.Port() == 0 {
const defaultPort = 51820
s.Endpoint.Port = defaultPort
s.Endpoint = netip.AddrPortFrom(s.Endpoint.Addr(), defaultPort)
}
if s.FirewallMark == 0 {
@@ -74,8 +73,7 @@ var (
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointMissing = errors.New("endpoint is missing")
ErrEndpointIPMissing = errors.New("endpoint IP is missing")
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNotValid = errors.New("interface address is not valid")
@@ -109,11 +107,9 @@ func (s *Settings) Check() (err error) {
}
switch {
case s.Endpoint == nil:
return fmt.Errorf("%w", ErrEndpointMissing)
case len(s.Endpoint.IP) == 0:
return fmt.Errorf("%w", ErrEndpointIPMissing)
case s.Endpoint.Port == 0:
case !s.Endpoint.Addr().IsValid():
return fmt.Errorf("%w", ErrEndpointAddrMissing)
case s.Endpoint.Port() == 0:
return fmt.Errorf("%w", ErrEndpointPortMissing)
}
@@ -198,7 +194,7 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
lines = append(lines, fieldPrefix+"Pre shared key: "+isSet)
endpointStr := notSet
if s.Endpoint != nil {
if s.Endpoint.Addr().IsValid() {
endpointStr = s.Endpoint.String()
}
lines = append(lines, fieldPrefix+"Endpoint: "+endpointStr)