feat(protonvpn): port forwarding support with NAT-PMP (#1543)
Co-authored-by: Nicholas Xavier <nicho@nicho.dev>
This commit is contained in:
@@ -40,7 +40,10 @@ func (p PortForwarding) validate(vpnProvider string) (err error) {
|
|||||||
if *p.Provider != "" {
|
if *p.Provider != "" {
|
||||||
providerSelected = *p.Provider
|
providerSelected = *p.Provider
|
||||||
}
|
}
|
||||||
validProviders := []string{providers.PrivateInternetAccess}
|
validProviders := []string{
|
||||||
|
providers.PrivateInternetAccess,
|
||||||
|
providers.Protonvpn,
|
||||||
|
}
|
||||||
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
|
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
|
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type WireguardSelection struct {
|
|||||||
// It is only used with VPN providers generating Wireguard
|
// It is only used with VPN providers generating Wireguard
|
||||||
// configurations specific to each server and user.
|
// configurations specific to each server and user.
|
||||||
// To indicate it should not be used, it should be set
|
// To indicate it should not be used, it should be set
|
||||||
// to netaddr.IPv4Unspecified(). It can never be the zero value
|
// to netip.IPv4Unspecified(). It can never be the zero value
|
||||||
// in the internal state.
|
// in the internal state.
|
||||||
EndpointIP netip.Addr `json:"endpoint_ip"`
|
EndpointIP netip.Addr `json:"endpoint_ip"`
|
||||||
// EndpointPort is a the server port to use for the VPN server.
|
// EndpointPort is a the server port to use for the VPN server.
|
||||||
|
|||||||
94
internal/natpmp/checks.go
Normal file
94
internal/natpmp/checks.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrRequestSizeTooSmall = errors.New("message size is too small")
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkRequest(request []byte) (err error) {
|
||||||
|
const minMessageSize = 2 // version number + operation code
|
||||||
|
if len(request) < minMessageSize {
|
||||||
|
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
||||||
|
ErrRequestSizeTooSmall, minMessageSize, len(request))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrResponseSizeTooSmall = errors.New("response size is too small")
|
||||||
|
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
|
||||||
|
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
|
||||||
|
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkResponse(response []byte, expectedOperationCode byte,
|
||||||
|
expectedResponseSize uint) (err error) {
|
||||||
|
const minResponseSize = 4
|
||||||
|
if len(response) < minResponseSize {
|
||||||
|
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
|
||||||
|
ErrResponseSizeTooSmall, minResponseSize, len(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response) != int(expectedResponseSize) {
|
||||||
|
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)",
|
||||||
|
ErrResponseSizeUnexpected, expectedResponseSize, len(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
protocolVersion := response[0]
|
||||||
|
if protocolVersion != 0 {
|
||||||
|
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
operationCode := response[1]
|
||||||
|
if operationCode != expectedOperationCode {
|
||||||
|
return fmt.Errorf("%w: expected 0x%x and got 0x%x",
|
||||||
|
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCode := binary.BigEndian.Uint16(response[2:4])
|
||||||
|
err = checkResultCode(resultCode)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("result code: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrVersionNotSupported = errors.New("version is not supported")
|
||||||
|
ErrNotAuthorized = errors.New("not authorized")
|
||||||
|
ErrNetworkFailure = errors.New("network failure")
|
||||||
|
ErrOutOfResources = errors.New("out of resources")
|
||||||
|
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
|
||||||
|
ErrResultCodeUnknown = errors.New("result code is unknown")
|
||||||
|
)
|
||||||
|
|
||||||
|
// checkResultCode checks the result code and returns an error
|
||||||
|
// if the result code is not a success (0).
|
||||||
|
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
|
||||||
|
//
|
||||||
|
//nolint:gomnd
|
||||||
|
func checkResultCode(resultCode uint16) (err error) {
|
||||||
|
switch resultCode {
|
||||||
|
case 0:
|
||||||
|
return nil
|
||||||
|
case 1:
|
||||||
|
return fmt.Errorf("%w", ErrVersionNotSupported)
|
||||||
|
case 2:
|
||||||
|
return fmt.Errorf("%w", ErrNotAuthorized)
|
||||||
|
case 3:
|
||||||
|
return fmt.Errorf("%w", ErrNetworkFailure)
|
||||||
|
case 4:
|
||||||
|
return fmt.Errorf("%w", ErrOutOfResources)
|
||||||
|
case 5:
|
||||||
|
return fmt.Errorf("%w", ErrOperationCodeNotSupported)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
161
internal/natpmp/checks_test.go
Normal file
161
internal/natpmp/checks_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_checkRequest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
request []byte
|
||||||
|
err error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"too_short": {
|
||||||
|
request: []byte{1},
|
||||||
|
err: ErrRequestSizeTooSmall,
|
||||||
|
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
request: []byte{0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := checkRequest(testCase.request)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, testCase.err)
|
||||||
|
if testCase.err != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_checkResponse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
response []byte
|
||||||
|
expectedOperationCode byte
|
||||||
|
expectedResponseSize uint
|
||||||
|
err error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"too_short": {
|
||||||
|
response: []byte{1},
|
||||||
|
err: ErrResponseSizeTooSmall,
|
||||||
|
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
|
||||||
|
},
|
||||||
|
"size_mismatch": {
|
||||||
|
response: []byte{0, 0, 0, 0},
|
||||||
|
expectedResponseSize: 5,
|
||||||
|
err: ErrResponseSizeUnexpected,
|
||||||
|
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
|
||||||
|
},
|
||||||
|
"protocol_unknown": {
|
||||||
|
response: []byte{1, 0, 0, 0},
|
||||||
|
expectedResponseSize: 4,
|
||||||
|
err: ErrProtocolVersionUnknown,
|
||||||
|
errMessage: "protocol version is unknown: 1",
|
||||||
|
},
|
||||||
|
"operation_code_unexpected": {
|
||||||
|
response: []byte{0, 2, 0, 0},
|
||||||
|
expectedOperationCode: 1,
|
||||||
|
expectedResponseSize: 4,
|
||||||
|
err: ErrOperationCodeUnexpected,
|
||||||
|
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
|
||||||
|
},
|
||||||
|
"result_code_failure": {
|
||||||
|
response: []byte{0, 1, 0, 1},
|
||||||
|
expectedOperationCode: 1,
|
||||||
|
expectedResponseSize: 4,
|
||||||
|
err: ErrVersionNotSupported,
|
||||||
|
errMessage: "result code: version is not supported",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
response: []byte{0, 1, 0, 0},
|
||||||
|
expectedOperationCode: 1,
|
||||||
|
expectedResponseSize: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := checkResponse(testCase.response,
|
||||||
|
testCase.expectedOperationCode,
|
||||||
|
testCase.expectedResponseSize)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, testCase.err)
|
||||||
|
if testCase.err != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_checkResultCode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
resultCode uint16
|
||||||
|
err error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"success": {},
|
||||||
|
"version_unsupported": {
|
||||||
|
resultCode: 1,
|
||||||
|
err: ErrVersionNotSupported,
|
||||||
|
errMessage: "version is not supported",
|
||||||
|
},
|
||||||
|
"not_authorized": {
|
||||||
|
resultCode: 2,
|
||||||
|
err: ErrNotAuthorized,
|
||||||
|
errMessage: "not authorized",
|
||||||
|
},
|
||||||
|
"network_failure": {
|
||||||
|
resultCode: 3,
|
||||||
|
err: ErrNetworkFailure,
|
||||||
|
errMessage: "network failure",
|
||||||
|
},
|
||||||
|
"out_of_resources": {
|
||||||
|
resultCode: 4,
|
||||||
|
err: ErrOutOfResources,
|
||||||
|
errMessage: "out of resources",
|
||||||
|
},
|
||||||
|
"unsupported_operation_code": {
|
||||||
|
resultCode: 5,
|
||||||
|
err: ErrOperationCodeNotSupported,
|
||||||
|
errMessage: "operation code is not supported",
|
||||||
|
},
|
||||||
|
"unknown": {
|
||||||
|
resultCode: 6,
|
||||||
|
err: ErrResultCodeUnknown,
|
||||||
|
errMessage: "result code is unknown: 6",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := checkResultCode(testCase.resultCode)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, testCase.err)
|
||||||
|
if testCase.err != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
28
internal/natpmp/externaladdress.go
Normal file
28
internal/natpmp/externaladdress.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExternalAddress fetches the duration since the start of epoch and the external
|
||||||
|
// IPv4 address of the gateway.
|
||||||
|
// See https://www.ietf.org/rfc/rfc6886.html#section-3.2
|
||||||
|
func (c *Client) ExternalAddress(ctx context.Context, gateway netip.Addr) (
|
||||||
|
durationSinceStartOfEpoch time.Duration,
|
||||||
|
externalIPv4Address netip.Addr, err error) {
|
||||||
|
request := []byte{0, 0} // version 0, operationCode 0
|
||||||
|
const responseSize = 12
|
||||||
|
response, err := c.rpc(ctx, gateway, request, responseSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, externalIPv4Address, fmt.Errorf("executing remote procedure call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondsSinceStartOfEpoch := binary.BigEndian.Uint32(response[4:8])
|
||||||
|
durationSinceStartOfEpoch = time.Duration(secondsSinceStartOfEpoch) * time.Second
|
||||||
|
externalIPv4Address = netip.AddrFrom4([4]byte{response[8], response[9], response[10], response[11]})
|
||||||
|
return durationSinceStartOfEpoch, externalIPv4Address, nil
|
||||||
|
}
|
||||||
71
internal/natpmp/externaladdress_test.go
Normal file
71
internal/natpmp/externaladdress_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Client_ExternalAddress(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
canceledCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ctx context.Context
|
||||||
|
gateway netip.Addr
|
||||||
|
initialRetry time.Duration
|
||||||
|
exchanges []udpExchange
|
||||||
|
durationSinceStartOfEpoch time.Duration
|
||||||
|
externalIPv4Address netip.Addr
|
||||||
|
err error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"failure": {
|
||||||
|
ctx: canceledCtx,
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
initialRetry: time.Millisecond,
|
||||||
|
err: context.Canceled,
|
||||||
|
errMessage: "executing remote procedure call: reading from udp connection: context canceled",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
initialRetry: time.Millisecond,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0, 0},
|
||||||
|
response: []byte{0x0, 0x80, 0x0, 0x0, 0x0, 0x13, 0xf2, 0x4f, 0x49, 0x8c, 0x36, 0x9a},
|
||||||
|
}},
|
||||||
|
durationSinceStartOfEpoch: time.Duration(0x13f24f) * time.Second,
|
||||||
|
externalIPv4Address: netip.AddrFrom4([4]byte{0x49, 0x8c, 0x36, 0x9a}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
remoteAddress := launchUDPServer(t, testCase.exchanges)
|
||||||
|
|
||||||
|
client := Client{
|
||||||
|
serverPort: uint16(remoteAddress.Port),
|
||||||
|
initialRetry: testCase.initialRetry,
|
||||||
|
maxRetries: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
durationSinceStartOfEpoch, externalIPv4Address, err :=
|
||||||
|
client.ExternalAddress(testCase.ctx, testCase.gateway)
|
||||||
|
assert.ErrorIs(t, err, testCase.err)
|
||||||
|
if testCase.err != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
assert.Equal(t, testCase.durationSinceStartOfEpoch, durationSinceStartOfEpoch)
|
||||||
|
assert.Equal(t, testCase.externalIPv4Address, externalIPv4Address)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
99
internal/natpmp/helpers_test.go
Normal file
99
internal/natpmp/helpers_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpExchange struct {
|
||||||
|
request []byte
|
||||||
|
response []byte
|
||||||
|
close bool // to trigger a client error
|
||||||
|
}
|
||||||
|
|
||||||
|
// launchUDPServer launches an UDP server which will expect
|
||||||
|
// the requests precised in each of the given exchanges,
|
||||||
|
// and respond the given corresponding response.
|
||||||
|
// The server shuts down gracefully at the end of the test.
|
||||||
|
// The remote address (127.0.0.1:port) is returned, where
|
||||||
|
// port is dynamically assigned by the OS so calling tests
|
||||||
|
// can run in parallel.
|
||||||
|
func launchUDPServer(t *testing.T, exchanges []udpExchange) (
|
||||||
|
remoteAddress *net.UDPAddr) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP("udp", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listeningAddress, ok := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
require.True(t, ok, "listening address is not UDP")
|
||||||
|
remoteAddress = &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: listeningAddress.Port,
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := conn.Close()
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
<-done
|
||||||
|
})
|
||||||
|
|
||||||
|
var maxBufferSize int
|
||||||
|
for _, exchange := range exchanges {
|
||||||
|
if len(exchange.request) > maxBufferSize {
|
||||||
|
maxBufferSize = len(exchange.request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := make([]byte, maxBufferSize)
|
||||||
|
|
||||||
|
ready := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
close(ready)
|
||||||
|
for _, exchange := range exchanges {
|
||||||
|
n, clientAddress, err := conn.ReadFromUDP(buffer)
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Error("at least one exchange is missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, len(exchange.request), n,
|
||||||
|
"request message size is unexpected")
|
||||||
|
if n > 0 {
|
||||||
|
assert.Equal(t, exchange.request, buffer[:n],
|
||||||
|
"request message is unexpected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if exchange.close {
|
||||||
|
err = conn.Close()
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
// connection might be already closed by client production code
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.WriteToUDP(exchange.response, clientAddress)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := conn.Close()
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
// The connection closing can be raced by the test
|
||||||
|
// cleanup function defined above.
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
<-ready
|
||||||
|
|
||||||
|
return remoteAddress
|
||||||
|
}
|
||||||
26
internal/natpmp/natpmp.go
Normal file
26
internal/natpmp/natpmp.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is a NAT-PMP protocol client.
|
||||||
|
type Client struct {
|
||||||
|
serverPort uint16
|
||||||
|
initialRetry time.Duration
|
||||||
|
maxRetries uint
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new NAT-PMP client.
|
||||||
|
func New() (client *Client) {
|
||||||
|
const natpmpPort = 5351
|
||||||
|
|
||||||
|
// Parameters described in https://www.ietf.org/rfc/rfc6886.html#section-3.1
|
||||||
|
const initialRetry = 250 * time.Millisecond
|
||||||
|
const maxTries = 9 // 64 seconds
|
||||||
|
return &Client{
|
||||||
|
serverPort: natpmpPort,
|
||||||
|
initialRetry: initialRetry,
|
||||||
|
maxRetries: maxTries,
|
||||||
|
}
|
||||||
|
}
|
||||||
20
internal/natpmp/natpmp_test.go
Normal file
20
internal/natpmp/natpmp_test.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_New(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
expectedClient := &Client{
|
||||||
|
serverPort: 5351,
|
||||||
|
initialRetry: 250 * time.Millisecond,
|
||||||
|
maxRetries: 9,
|
||||||
|
}
|
||||||
|
client := New()
|
||||||
|
assert.Equal(t, expectedClient, client)
|
||||||
|
}
|
||||||
60
internal/natpmp/portmapping.go
Normal file
60
internal/natpmp/portmapping.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNetworkProtocolUnknown = errors.New("network protocol is unknown")
|
||||||
|
ErrLifetimeTooLong = errors.New("lifetime is too long")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Add or delete a port mapping. To delete a mapping, set both the
|
||||||
|
// requestedExternalPort and lifetime to 0.
|
||||||
|
// See https://www.ietf.org/rfc/rfc6886.html#section-3.3
|
||||||
|
func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr,
|
||||||
|
protocol string, internalPort, requestedExternalPort uint16,
|
||||||
|
lifetime time.Duration) (durationSinceStartOfEpoch time.Duration,
|
||||||
|
assignedInternalPort, assignedExternalPort uint16, assignedLifetime time.Duration,
|
||||||
|
err error) {
|
||||||
|
lifetimeSecondsFloat := lifetime.Seconds()
|
||||||
|
const maxLifetimeSeconds = uint64(^uint32(0))
|
||||||
|
if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds {
|
||||||
|
return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds",
|
||||||
|
ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds)
|
||||||
|
}
|
||||||
|
const messageSize = 12
|
||||||
|
message := make([]byte, messageSize)
|
||||||
|
message[0] = 0 // Version 0
|
||||||
|
switch protocol {
|
||||||
|
case "udp":
|
||||||
|
message[1] = 1 // operationCode 1
|
||||||
|
case "tcp":
|
||||||
|
message[1] = 2 // operationCode 2
|
||||||
|
default:
|
||||||
|
return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol)
|
||||||
|
}
|
||||||
|
// [2:3] are reserved.
|
||||||
|
binary.BigEndian.PutUint16(message[4:6], internalPort)
|
||||||
|
binary.BigEndian.PutUint16(message[6:8], requestedExternalPort)
|
||||||
|
binary.BigEndian.PutUint32(message[8:12], uint32(lifetimeSecondsFloat))
|
||||||
|
|
||||||
|
const responseSize = 16
|
||||||
|
response, err := c.rpc(ctx, gateway, message, responseSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, 0, fmt.Errorf("executing remote procedure call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondsSinceStartOfEpoch := binary.BigEndian.Uint32(response[4:8])
|
||||||
|
durationSinceStartOfEpoch = time.Duration(secondsSinceStartOfEpoch) * time.Second
|
||||||
|
assignedInternalPort = binary.BigEndian.Uint16(response[8:10])
|
||||||
|
assignedExternalPort = binary.BigEndian.Uint16(response[10:12])
|
||||||
|
lifetimeInSeconds := binary.BigEndian.Uint32(response[12:16])
|
||||||
|
assignedLifetime = time.Duration(lifetimeInSeconds) * time.Second
|
||||||
|
return durationSinceStartOfEpoch, assignedInternalPort, assignedExternalPort, assignedLifetime, nil
|
||||||
|
}
|
||||||
149
internal/natpmp/portmapping_test.go
Normal file
149
internal/natpmp/portmapping_test.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Client_AddPortMapping(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ctx context.Context
|
||||||
|
gateway netip.Addr
|
||||||
|
protocol string
|
||||||
|
internalPort uint16
|
||||||
|
requestedExternalPort uint16
|
||||||
|
lifetime time.Duration
|
||||||
|
initialRetry time.Duration
|
||||||
|
exchanges []udpExchange
|
||||||
|
durationSinceStartOfEpoch time.Duration
|
||||||
|
assignedInternalPort uint16
|
||||||
|
assignedExternalPort uint16
|
||||||
|
assignedLifetime time.Duration
|
||||||
|
err error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"lifetime_too_long": {
|
||||||
|
lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second,
|
||||||
|
err: ErrLifetimeTooLong,
|
||||||
|
errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds",
|
||||||
|
},
|
||||||
|
"protocol_unknown": {
|
||||||
|
lifetime: time.Second,
|
||||||
|
protocol: "xyz",
|
||||||
|
err: ErrNetworkProtocolUnknown,
|
||||||
|
errMessage: "network protocol is unknown: xyz",
|
||||||
|
},
|
||||||
|
"rpc_error": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
protocol: "udp",
|
||||||
|
internalPort: 123,
|
||||||
|
requestedExternalPort: 456,
|
||||||
|
lifetime: 1200 * time.Second,
|
||||||
|
initialRetry: time.Millisecond,
|
||||||
|
exchanges: []udpExchange{{close: true}},
|
||||||
|
err: ErrConnectionTimeout,
|
||||||
|
errMessage: "executing remote procedure call: connection timeout: after 1ms",
|
||||||
|
},
|
||||||
|
"add_udp": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
protocol: "udp",
|
||||||
|
internalPort: 123,
|
||||||
|
requestedExternalPort: 456,
|
||||||
|
lifetime: 1200 * time.Second,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x1, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0x0, 0x81, 0x0, 0x0, 0x0, 0x13, 0xfe, 0xff, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
}},
|
||||||
|
durationSinceStartOfEpoch: 0x13feff * time.Second,
|
||||||
|
assignedInternalPort: 0x7b,
|
||||||
|
assignedExternalPort: 0x1c8,
|
||||||
|
assignedLifetime: 0x4b0 * time.Second,
|
||||||
|
},
|
||||||
|
"add_tcp": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
protocol: "tcp",
|
||||||
|
internalPort: 123,
|
||||||
|
requestedExternalPort: 456,
|
||||||
|
lifetime: 1200 * time.Second,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x14, 0x3, 0x21, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
}},
|
||||||
|
durationSinceStartOfEpoch: 0x140321 * time.Second,
|
||||||
|
assignedInternalPort: 0x7b,
|
||||||
|
assignedExternalPort: 0x1c8,
|
||||||
|
assignedLifetime: 0x4b0 * time.Second,
|
||||||
|
},
|
||||||
|
"remove_udp": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
protocol: "udp",
|
||||||
|
internalPort: 123,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x1, 0x0, 0x0, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
response: []byte{0x0, 0x81, 0x0, 0x0, 0x0, 0x14, 0x3, 0xd5, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
}},
|
||||||
|
durationSinceStartOfEpoch: 0x1403d5 * time.Second,
|
||||||
|
assignedInternalPort: 0x7b,
|
||||||
|
},
|
||||||
|
"remove_tcp": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
protocol: "tcp",
|
||||||
|
internalPort: 123,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
response: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
}},
|
||||||
|
durationSinceStartOfEpoch: 0x140496 * time.Second,
|
||||||
|
assignedInternalPort: 0x7b,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
remoteAddress := launchUDPServer(t, testCase.exchanges)
|
||||||
|
|
||||||
|
client := Client{
|
||||||
|
serverPort: uint16(remoteAddress.Port),
|
||||||
|
initialRetry: testCase.initialRetry,
|
||||||
|
maxRetries: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
durationSinceStartOfEpoch, assignedInternalPort,
|
||||||
|
assignedExternalPort, assignedLifetime, err :=
|
||||||
|
client.AddPortMapping(testCase.ctx, testCase.gateway,
|
||||||
|
testCase.protocol, testCase.internalPort,
|
||||||
|
testCase.requestedExternalPort, testCase.lifetime)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.durationSinceStartOfEpoch, durationSinceStartOfEpoch)
|
||||||
|
assert.Equal(t, testCase.assignedInternalPort, assignedInternalPort)
|
||||||
|
assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort)
|
||||||
|
assert.Equal(t, testCase.assignedLifetime, assignedLifetime)
|
||||||
|
if testCase.errMessage != "" {
|
||||||
|
if testCase.err != nil {
|
||||||
|
assert.ErrorIs(t, err, testCase.err)
|
||||||
|
}
|
||||||
|
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
123
internal/natpmp/rpc.go
Normal file
123
internal/natpmp/rpc.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified")
|
||||||
|
ErrConnectionTimeout = errors.New("connection timeout")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
||||||
|
request []byte, responseSize uint) (
|
||||||
|
response []byte, err error) {
|
||||||
|
if gateway.IsUnspecified() || !gateway.IsValid() {
|
||||||
|
return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = checkRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("checking request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayAddress := &net.UDPAddr{
|
||||||
|
IP: gateway.AsSlice(),
|
||||||
|
Port: int(c.serverPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
connection, err := net.DialUDP("udp", nil, gatewayAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing udp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
endGoroutineDone := make(chan struct{})
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
<-endGoroutineDone
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer close(endGoroutineDone)
|
||||||
|
// Context is canceled either by the parent context or
|
||||||
|
// when this function returns.
|
||||||
|
<-ctx.Done()
|
||||||
|
closeErr := connection.Close()
|
||||||
|
if closeErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
err = fmt.Errorf("closing connection: %w", closeErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%w; closing connection: %w", err, closeErr)
|
||||||
|
}()
|
||||||
|
|
||||||
|
const maxResponseSize = 16
|
||||||
|
response = make([]byte, maxResponseSize)
|
||||||
|
|
||||||
|
// Retry duration doubles on every network error
|
||||||
|
// Note it does not double if the source IP mismatches the gateway IP.
|
||||||
|
retryDuration := c.initialRetry
|
||||||
|
|
||||||
|
var totalRetryDuration time.Duration
|
||||||
|
|
||||||
|
var retryCount uint
|
||||||
|
for retryCount = 0; retryCount < c.maxRetries; retryCount++ {
|
||||||
|
deadline := time.Now().Add(retryDuration)
|
||||||
|
err = connection.SetDeadline(deadline)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting connection deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = connection.Write(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("writing to connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bytesRead, receivedRemoteAddress, err := connection.ReadFromUDP(response)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, fmt.Errorf("reading from udp connection: %w", ctx.Err())
|
||||||
|
}
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||||
|
totalRetryDuration += retryDuration
|
||||||
|
retryDuration *= 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("reading from udp connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !receivedRemoteAddress.IP.Equal(gatewayAddress.IP) {
|
||||||
|
// Upon receiving a response packet, the client MUST check the source IP
|
||||||
|
// address, and silently discard the packet if the address is not the
|
||||||
|
// address of the gateway to which the request was sent.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
response = response[:bytesRead]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if retryCount == c.maxRetries {
|
||||||
|
return nil, fmt.Errorf("%w: after %s",
|
||||||
|
ErrConnectionTimeout, totalRetryDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
|
||||||
|
// 255 are corresponding server responses.
|
||||||
|
const operationCodeMask = 128
|
||||||
|
expectedOperationCode := request[1] | operationCodeMask
|
||||||
|
err = checkResponse(response, expectedOperationCode, responseSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("checking response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
166
internal/natpmp/rpc_test.go
Normal file
166
internal/natpmp/rpc_test.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package natpmp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Client_rpc(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ctx context.Context
|
||||||
|
gateway netip.Addr
|
||||||
|
request []byte
|
||||||
|
responseSize uint
|
||||||
|
initialRetry time.Duration
|
||||||
|
exchanges []udpExchange
|
||||||
|
expectedResponse []byte
|
||||||
|
err error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"gateway_ip_unspecified": {
|
||||||
|
gateway: netip.IPv6Unspecified(),
|
||||||
|
request: []byte{0, 0},
|
||||||
|
err: ErrGatewayIPUnspecified,
|
||||||
|
errMessage: "gateway IP is unspecified",
|
||||||
|
},
|
||||||
|
"request_too_small": {
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0},
|
||||||
|
initialRetry: time.Second,
|
||||||
|
err: ErrRequestSizeTooSmall,
|
||||||
|
errMessage: `checking request: message size is too small: ` +
|
||||||
|
`need at least 2 bytes and got 1 byte\(s\)`,
|
||||||
|
},
|
||||||
|
"write_error": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0, 0},
|
||||||
|
errMessage: `writing to connection: write udp ` +
|
||||||
|
`127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: ` +
|
||||||
|
`i/o timeout`,
|
||||||
|
},
|
||||||
|
"call_error": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0, 1},
|
||||||
|
initialRetry: time.Millisecond,
|
||||||
|
exchanges: []udpExchange{
|
||||||
|
{request: []byte{0, 1}, close: true},
|
||||||
|
},
|
||||||
|
err: ErrConnectionTimeout,
|
||||||
|
errMessage: "connection timeout: after 1ms",
|
||||||
|
},
|
||||||
|
"response_too_small": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0, 0},
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0, 0},
|
||||||
|
response: []byte{1},
|
||||||
|
}},
|
||||||
|
err: ErrResponseSizeTooSmall,
|
||||||
|
errMessage: `checking response: response size is too small: ` +
|
||||||
|
`need at least 4 bytes and got 1 byte\(s\)`,
|
||||||
|
},
|
||||||
|
"unexpected_response_size": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
responseSize: 5,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0, 1, 2, 3}, // size 4
|
||||||
|
}},
|
||||||
|
err: ErrResponseSizeUnexpected,
|
||||||
|
errMessage: `checking response: response size is unexpected: ` +
|
||||||
|
`expected 5 bytes and got 4 byte\(s\)`,
|
||||||
|
},
|
||||||
|
"unknown_protocol_version": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
responseSize: 16,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
}},
|
||||||
|
err: ErrProtocolVersionUnknown,
|
||||||
|
errMessage: "checking response: protocol version is unknown: 1",
|
||||||
|
},
|
||||||
|
"unexpected_operation_code": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
responseSize: 16,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
}},
|
||||||
|
err: ErrOperationCodeUnexpected,
|
||||||
|
errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88",
|
||||||
|
},
|
||||||
|
"failure_result_code": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
responseSize: 16,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
}},
|
||||||
|
err: ErrResultCodeUnknown,
|
||||||
|
errMessage: "checking response: result code: result code is unknown: 17",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
ctx: context.Background(),
|
||||||
|
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
responseSize: 16,
|
||||||
|
initialRetry: time.Second,
|
||||||
|
exchanges: []udpExchange{{
|
||||||
|
request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0},
|
||||||
|
response: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x0, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
}},
|
||||||
|
expectedResponse: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x0, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
remoteAddress := launchUDPServer(t, testCase.exchanges)
|
||||||
|
|
||||||
|
client := Client{
|
||||||
|
serverPort: uint16(remoteAddress.Port),
|
||||||
|
initialRetry: testCase.initialRetry,
|
||||||
|
maxRetries: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := client.rpc(testCase.ctx, testCase.gateway,
|
||||||
|
testCase.request, testCase.responseSize)
|
||||||
|
|
||||||
|
if testCase.errMessage != "" {
|
||||||
|
if testCase.err != nil {
|
||||||
|
assert.ErrorIs(t, err, testCase.err)
|
||||||
|
}
|
||||||
|
assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error())
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, testCase.expectedResponse, response)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -34,8 +34,8 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
portCh <- port
|
portCh <- port
|
||||||
|
|
||||||
// Infinite loop
|
// Infinite loop
|
||||||
err = startData.PortForwarder.KeepPortForward(ctx,
|
err = startData.PortForwarder.KeepPortForward(ctx, port,
|
||||||
startData.Gateway, startData.ServerName)
|
startData.Gateway, startData.ServerName, l.logger)
|
||||||
errorCh <- err
|
errorCh <- err
|
||||||
}(pfCtx, startData)
|
}(pfCtx, startData)
|
||||||
|
|
||||||
|
|||||||
@@ -91,8 +91,8 @@ var (
|
|||||||
ErrPortForwardedExpired = errors.New("port forwarded data expired")
|
ErrPortForwardedExpired = errors.New("port forwarded data expired")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Provider) KeepPortForward(ctx context.Context,
|
func (p *Provider) KeepPortForward(ctx context.Context, _ uint16,
|
||||||
gateway netip.Addr, serverName string) (err error) {
|
gateway netip.Addr, serverName string, _ utils.Logger) (err error) {
|
||||||
privateIPClient, err := newHTTPClient(serverName)
|
privateIPClient, err := newHTTPClient(serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating custom HTTP client: %w", err)
|
return fmt.Errorf("creating custom HTTP client: %w", err)
|
||||||
|
|||||||
103
internal/provider/protonvpn/portforward.go
Normal file
103
internal/provider/protonvpn/portforward.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package protonvpn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/natpmp"
|
||||||
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrGatewayIPNotValid = errors.New("gateway IP address is not valid")
|
||||||
|
)
|
||||||
|
|
||||||
|
// PortForward obtains a VPN server side port forwarded from ProtonVPN gateway.
|
||||||
|
func (p *Provider) PortForward(ctx context.Context, _ *http.Client,
|
||||||
|
logger utils.Logger, gateway netip.Addr, _ string) (
|
||||||
|
port uint16, err error) {
|
||||||
|
if !gateway.IsValid() {
|
||||||
|
return 0, fmt.Errorf("%w", ErrGatewayIPNotValid)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := natpmp.New()
|
||||||
|
_, externalIPv4Address, err := client.ExternalAddress(ctx,
|
||||||
|
gateway)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("getting external IPv4 address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("gateway external IPv4 address is " + externalIPv4Address.String())
|
||||||
|
networkProtocols := []string{"udp", "tcp"}
|
||||||
|
const internalPort, externalPort = 0, 0
|
||||||
|
const lifetime = 60 * time.Second
|
||||||
|
for _, networkProtocol := range networkProtocols {
|
||||||
|
_, assignedInternalPort, assignedExternalPort, assignedLiftetime, err :=
|
||||||
|
client.AddPortMapping(ctx, gateway, networkProtocol,
|
||||||
|
internalPort, externalPort, lifetime)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("adding port mapping: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if assignedLiftetime != lifetime {
|
||||||
|
logger.Warn(fmt.Sprintf("assigned lifetime %s differs"+
|
||||||
|
" from requested lifetime %s",
|
||||||
|
assignedLiftetime, lifetime))
|
||||||
|
}
|
||||||
|
|
||||||
|
if assignedInternalPort != assignedExternalPort {
|
||||||
|
logger.Warn(fmt.Sprintf("internal port assigned %d differs"+
|
||||||
|
" from external port assigned %d",
|
||||||
|
assignedInternalPort, assignedExternalPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
port = assignedExternalPort
|
||||||
|
}
|
||||||
|
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) KeepPortForward(ctx context.Context, port uint16,
|
||||||
|
gateway netip.Addr, _ string, logger utils.Logger) (err error) {
|
||||||
|
client := natpmp.New()
|
||||||
|
const refreshTimeout = 45 * time.Second
|
||||||
|
timer := time.NewTimer(refreshTimeout)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
networkProtocols := []string{"udp", "tcp"}
|
||||||
|
const internalPort = 0
|
||||||
|
const lifetime = 60 * time.Second
|
||||||
|
|
||||||
|
for _, networkProtocol := range networkProtocols {
|
||||||
|
_, assignedInternalPort, assignedExternalPort, assignedLiftetime, err :=
|
||||||
|
client.AddPortMapping(ctx, gateway, networkProtocol,
|
||||||
|
internalPort, port, lifetime)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("adding port mapping: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if assignedLiftetime != lifetime {
|
||||||
|
logger.Warn(fmt.Sprintf("assigned lifetime %s differs"+
|
||||||
|
" from requested lifetime %s",
|
||||||
|
assignedLiftetime, lifetime))
|
||||||
|
}
|
||||||
|
|
||||||
|
if assignedInternalPort != assignedExternalPort {
|
||||||
|
logger.Warn(fmt.Sprintf("internal port assigned %d differs"+
|
||||||
|
" from external port assigned %d",
|
||||||
|
assignedInternalPort, assignedExternalPort))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
timer.Reset(refreshTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,23 +7,20 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||||
"github.com/qdm12/gluetun/internal/provider/common"
|
"github.com/qdm12/gluetun/internal/provider/common"
|
||||||
"github.com/qdm12/gluetun/internal/provider/protonvpn/updater"
|
"github.com/qdm12/gluetun/internal/provider/protonvpn/updater"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Provider struct {
|
type Provider struct {
|
||||||
storage common.Storage
|
storage common.Storage
|
||||||
randSource rand.Source
|
randSource rand.Source
|
||||||
utils.NoPortForwarder
|
|
||||||
common.Fetcher
|
common.Fetcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(storage common.Storage, randSource rand.Source,
|
func New(storage common.Storage, randSource rand.Source,
|
||||||
client *http.Client, updaterWarner common.Warner) *Provider {
|
client *http.Client, updaterWarner common.Warner) *Provider {
|
||||||
return &Provider{
|
return &Provider{
|
||||||
storage: storage,
|
storage: storage,
|
||||||
randSource: randSource,
|
randSource: randSource,
|
||||||
NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn),
|
Fetcher: updater.New(client, updaterWarner),
|
||||||
Fetcher: updater.New(client, updaterWarner),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,6 @@ type PortForwarder interface {
|
|||||||
PortForward(ctx context.Context, client *http.Client,
|
PortForward(ctx context.Context, client *http.Client,
|
||||||
logger utils.Logger, gateway netip.Addr, serverName string) (
|
logger utils.Logger, gateway netip.Addr, serverName string) (
|
||||||
port uint16, err error)
|
port uint16, err error)
|
||||||
KeepPortForward(ctx context.Context, gateway netip.Addr,
|
KeepPortForward(ctx context.Context, port uint16, gateway netip.Addr,
|
||||||
serverName string) (err error)
|
serverName string, _ utils.Logger) (err error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ type NoPortForwarder interface {
|
|||||||
PortForward(ctx context.Context, client *http.Client,
|
PortForward(ctx context.Context, client *http.Client,
|
||||||
logger Logger, gateway netip.Addr, serverName string) (
|
logger Logger, gateway netip.Addr, serverName string) (
|
||||||
port uint16, err error)
|
port uint16, err error)
|
||||||
KeepPortForward(ctx context.Context, gateway netip.Addr,
|
KeepPortForward(ctx context.Context, port uint16, gateway netip.Addr,
|
||||||
serverName string) (err error)
|
serverName string, logger Logger) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type NoPortForwarding struct {
|
type NoPortForwarding struct {
|
||||||
@@ -33,6 +33,7 @@ func (n *NoPortForwarding) PortForward(context.Context, *http.Client,
|
|||||||
return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NoPortForwarding) KeepPortForward(context.Context, netip.Addr, string) (err error) {
|
func (n *NoPortForwarding) KeepPortForward(context.Context, uint16, netip.Addr,
|
||||||
|
string, Logger) (err error) {
|
||||||
return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user