diff --git a/internal/configuration/settings/portforward.go b/internal/configuration/settings/portforward.go index cc4c4fd6..a6b60c9c 100644 --- a/internal/configuration/settings/portforward.go +++ b/internal/configuration/settings/portforward.go @@ -40,7 +40,10 @@ func (p PortForwarding) validate(vpnProvider string) (err error) { if *p.Provider != "" { providerSelected = *p.Provider } - validProviders := []string{providers.PrivateInternetAccess} + validProviders := []string{ + providers.PrivateInternetAccess, + providers.Protonvpn, + } if err = validate.IsOneOf(providerSelected, validProviders...); err != nil { return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err) } diff --git a/internal/configuration/settings/wireguardselection.go b/internal/configuration/settings/wireguardselection.go index 20e9ff04..bf061f99 100644 --- a/internal/configuration/settings/wireguardselection.go +++ b/internal/configuration/settings/wireguardselection.go @@ -16,7 +16,7 @@ type WireguardSelection struct { // It is only used with VPN providers generating Wireguard // configurations specific to each server and user. // To indicate it should not be used, it should be set - // to netaddr.IPv4Unspecified(). It can never be the zero value + // to netip.IPv4Unspecified(). It can never be the zero value // in the internal state. EndpointIP netip.Addr `json:"endpoint_ip"` // EndpointPort is a the server port to use for the VPN server. diff --git a/internal/natpmp/checks.go b/internal/natpmp/checks.go new file mode 100644 index 00000000..4ac4b3f2 --- /dev/null +++ b/internal/natpmp/checks.go @@ -0,0 +1,94 @@ +package natpmp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +var ( + ErrRequestSizeTooSmall = errors.New("message size is too small") +) + +func checkRequest(request []byte) (err error) { + const minMessageSize = 2 // version number + operation code + if len(request) < minMessageSize { + return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)", + ErrRequestSizeTooSmall, minMessageSize, len(request)) + } + + return nil +} + +var ( + ErrResponseSizeTooSmall = errors.New("response size is too small") + ErrResponseSizeUnexpected = errors.New("response size is unexpected") + ErrProtocolVersionUnknown = errors.New("protocol version is unknown") + ErrOperationCodeUnexpected = errors.New("operation code is unexpected") +) + +func checkResponse(response []byte, expectedOperationCode byte, + expectedResponseSize uint) (err error) { + const minResponseSize = 4 + if len(response) < minResponseSize { + return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)", + ErrResponseSizeTooSmall, minResponseSize, len(response)) + } + + if len(response) != int(expectedResponseSize) { + return fmt.Errorf("%w: expected %d bytes and got %d byte(s)", + ErrResponseSizeUnexpected, expectedResponseSize, len(response)) + } + + protocolVersion := response[0] + if protocolVersion != 0 { + return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion) + } + + operationCode := response[1] + if operationCode != expectedOperationCode { + return fmt.Errorf("%w: expected 0x%x and got 0x%x", + ErrOperationCodeUnexpected, expectedOperationCode, operationCode) + } + + resultCode := binary.BigEndian.Uint16(response[2:4]) + err = checkResultCode(resultCode) + if err != nil { + return fmt.Errorf("result code: %w", err) + } + + return nil +} + +var ( + ErrVersionNotSupported = errors.New("version is not supported") + ErrNotAuthorized = errors.New("not authorized") + ErrNetworkFailure = errors.New("network failure") + ErrOutOfResources = errors.New("out of resources") + ErrOperationCodeNotSupported = errors.New("operation code is not supported") + ErrResultCodeUnknown = errors.New("result code is unknown") +) + +// checkResultCode checks the result code and returns an error +// if the result code is not a success (0). +// See https://www.ietf.org/rfc/rfc6886.html#section-3.5 +// +//nolint:gomnd +func checkResultCode(resultCode uint16) (err error) { + switch resultCode { + case 0: + return nil + case 1: + return fmt.Errorf("%w", ErrVersionNotSupported) + case 2: + return fmt.Errorf("%w", ErrNotAuthorized) + case 3: + return fmt.Errorf("%w", ErrNetworkFailure) + case 4: + return fmt.Errorf("%w", ErrOutOfResources) + case 5: + return fmt.Errorf("%w", ErrOperationCodeNotSupported) + default: + return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode) + } +} diff --git a/internal/natpmp/checks_test.go b/internal/natpmp/checks_test.go new file mode 100644 index 00000000..096b82e3 --- /dev/null +++ b/internal/natpmp/checks_test.go @@ -0,0 +1,161 @@ +package natpmp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_checkRequest(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + request []byte + err error + errMessage string + }{ + "too_short": { + request: []byte{1}, + err: ErrRequestSizeTooSmall, + errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)", + }, + "success": { + request: []byte{0, 0}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := checkRequest(testCase.request) + + assert.ErrorIs(t, err, testCase.err) + if testCase.err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_checkResponse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + response []byte + expectedOperationCode byte + expectedResponseSize uint + err error + errMessage string + }{ + "too_short": { + response: []byte{1}, + err: ErrResponseSizeTooSmall, + errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)", + }, + "size_mismatch": { + response: []byte{0, 0, 0, 0}, + expectedResponseSize: 5, + err: ErrResponseSizeUnexpected, + errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)", + }, + "protocol_unknown": { + response: []byte{1, 0, 0, 0}, + expectedResponseSize: 4, + err: ErrProtocolVersionUnknown, + errMessage: "protocol version is unknown: 1", + }, + "operation_code_unexpected": { + response: []byte{0, 2, 0, 0}, + expectedOperationCode: 1, + expectedResponseSize: 4, + err: ErrOperationCodeUnexpected, + errMessage: "operation code is unexpected: expected 0x1 and got 0x2", + }, + "result_code_failure": { + response: []byte{0, 1, 0, 1}, + expectedOperationCode: 1, + expectedResponseSize: 4, + err: ErrVersionNotSupported, + errMessage: "result code: version is not supported", + }, + "success": { + response: []byte{0, 1, 0, 0}, + expectedOperationCode: 1, + expectedResponseSize: 4, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := checkResponse(testCase.response, + testCase.expectedOperationCode, + testCase.expectedResponseSize) + + assert.ErrorIs(t, err, testCase.err) + if testCase.err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_checkResultCode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + resultCode uint16 + err error + errMessage string + }{ + "success": {}, + "version_unsupported": { + resultCode: 1, + err: ErrVersionNotSupported, + errMessage: "version is not supported", + }, + "not_authorized": { + resultCode: 2, + err: ErrNotAuthorized, + errMessage: "not authorized", + }, + "network_failure": { + resultCode: 3, + err: ErrNetworkFailure, + errMessage: "network failure", + }, + "out_of_resources": { + resultCode: 4, + err: ErrOutOfResources, + errMessage: "out of resources", + }, + "unsupported_operation_code": { + resultCode: 5, + err: ErrOperationCodeNotSupported, + errMessage: "operation code is not supported", + }, + "unknown": { + resultCode: 6, + err: ErrResultCodeUnknown, + errMessage: "result code is unknown: 6", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := checkResultCode(testCase.resultCode) + + assert.ErrorIs(t, err, testCase.err) + if testCase.err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/natpmp/externaladdress.go b/internal/natpmp/externaladdress.go new file mode 100644 index 00000000..971a59e1 --- /dev/null +++ b/internal/natpmp/externaladdress.go @@ -0,0 +1,28 @@ +package natpmp + +import ( + "context" + "encoding/binary" + "fmt" + "net/netip" + "time" +) + +// ExternalAddress fetches the duration since the start of epoch and the external +// IPv4 address of the gateway. +// See https://www.ietf.org/rfc/rfc6886.html#section-3.2 +func (c *Client) ExternalAddress(ctx context.Context, gateway netip.Addr) ( + durationSinceStartOfEpoch time.Duration, + externalIPv4Address netip.Addr, err error) { + request := []byte{0, 0} // version 0, operationCode 0 + const responseSize = 12 + response, err := c.rpc(ctx, gateway, request, responseSize) + if err != nil { + return 0, externalIPv4Address, fmt.Errorf("executing remote procedure call: %w", err) + } + + secondsSinceStartOfEpoch := binary.BigEndian.Uint32(response[4:8]) + durationSinceStartOfEpoch = time.Duration(secondsSinceStartOfEpoch) * time.Second + externalIPv4Address = netip.AddrFrom4([4]byte{response[8], response[9], response[10], response[11]}) + return durationSinceStartOfEpoch, externalIPv4Address, nil +} diff --git a/internal/natpmp/externaladdress_test.go b/internal/natpmp/externaladdress_test.go new file mode 100644 index 00000000..cdde3e62 --- /dev/null +++ b/internal/natpmp/externaladdress_test.go @@ -0,0 +1,71 @@ +package natpmp + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Client_ExternalAddress(t *testing.T) { + t.Parallel() + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + testCases := map[string]struct { + ctx context.Context + gateway netip.Addr + initialRetry time.Duration + exchanges []udpExchange + durationSinceStartOfEpoch time.Duration + externalIPv4Address netip.Addr + err error + errMessage string + }{ + "failure": { + ctx: canceledCtx, + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + initialRetry: time.Millisecond, + err: context.Canceled, + errMessage: "executing remote procedure call: reading from udp connection: context canceled", + }, + "success": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + initialRetry: time.Millisecond, + exchanges: []udpExchange{{ + request: []byte{0, 0}, + response: []byte{0x0, 0x80, 0x0, 0x0, 0x0, 0x13, 0xf2, 0x4f, 0x49, 0x8c, 0x36, 0x9a}, + }}, + durationSinceStartOfEpoch: time.Duration(0x13f24f) * time.Second, + externalIPv4Address: netip.AddrFrom4([4]byte{0x49, 0x8c, 0x36, 0x9a}), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + remoteAddress := launchUDPServer(t, testCase.exchanges) + + client := Client{ + serverPort: uint16(remoteAddress.Port), + initialRetry: testCase.initialRetry, + maxRetries: 1, + } + + durationSinceStartOfEpoch, externalIPv4Address, err := + client.ExternalAddress(testCase.ctx, testCase.gateway) + assert.ErrorIs(t, err, testCase.err) + if testCase.err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.durationSinceStartOfEpoch, durationSinceStartOfEpoch) + assert.Equal(t, testCase.externalIPv4Address, externalIPv4Address) + }) + } +} diff --git a/internal/natpmp/helpers_test.go b/internal/natpmp/helpers_test.go new file mode 100644 index 00000000..0aff9f30 --- /dev/null +++ b/internal/natpmp/helpers_test.go @@ -0,0 +1,99 @@ +package natpmp + +import ( + "errors" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type udpExchange struct { + request []byte + response []byte + close bool // to trigger a client error +} + +// launchUDPServer launches an UDP server which will expect +// the requests precised in each of the given exchanges, +// and respond the given corresponding response. +// The server shuts down gracefully at the end of the test. +// The remote address (127.0.0.1:port) is returned, where +// port is dynamically assigned by the OS so calling tests +// can run in parallel. +func launchUDPServer(t *testing.T, exchanges []udpExchange) ( + remoteAddress *net.UDPAddr) { + t.Helper() + + conn, err := net.ListenUDP("udp", nil) + require.NoError(t, err) + + listeningAddress, ok := conn.LocalAddr().(*net.UDPAddr) + require.True(t, ok, "listening address is not UDP") + remoteAddress = &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: listeningAddress.Port, + } + + done := make(chan struct{}) + t.Cleanup(func() { + err := conn.Close() + if !errors.Is(err, net.ErrClosed) { + assert.NoError(t, err) + } + <-done + }) + + var maxBufferSize int + for _, exchange := range exchanges { + if len(exchange.request) > maxBufferSize { + maxBufferSize = len(exchange.request) + } + } + + buffer := make([]byte, maxBufferSize) + + ready := make(chan struct{}) + go func() { + defer close(done) + close(ready) + for _, exchange := range exchanges { + n, clientAddress, err := conn.ReadFromUDP(buffer) + if errors.Is(err, net.ErrClosed) { + t.Error("at least one exchange is missing") + return + } + require.NoError(t, err) + + assert.Equal(t, len(exchange.request), n, + "request message size is unexpected") + if n > 0 { + assert.Equal(t, exchange.request, buffer[:n], + "request message is unexpected") + } + + if exchange.close { + err = conn.Close() + if !errors.Is(err, net.ErrClosed) { + // connection might be already closed by client production code + assert.NoError(t, err) + } + return + } + + _, err = conn.WriteToUDP(exchange.response, clientAddress) + require.NoError(t, err) + } + + err := conn.Close() + if !errors.Is(err, net.ErrClosed) { + // The connection closing can be raced by the test + // cleanup function defined above. + assert.NoError(t, err) + } + }() + <-ready + + return remoteAddress +} diff --git a/internal/natpmp/natpmp.go b/internal/natpmp/natpmp.go new file mode 100644 index 00000000..6de1c54a --- /dev/null +++ b/internal/natpmp/natpmp.go @@ -0,0 +1,26 @@ +package natpmp + +import ( + "time" +) + +// Client is a NAT-PMP protocol client. +type Client struct { + serverPort uint16 + initialRetry time.Duration + maxRetries uint +} + +// New creates a new NAT-PMP client. +func New() (client *Client) { + const natpmpPort = 5351 + + // Parameters described in https://www.ietf.org/rfc/rfc6886.html#section-3.1 + const initialRetry = 250 * time.Millisecond + const maxTries = 9 // 64 seconds + return &Client{ + serverPort: natpmpPort, + initialRetry: initialRetry, + maxRetries: maxTries, + } +} diff --git a/internal/natpmp/natpmp_test.go b/internal/natpmp/natpmp_test.go new file mode 100644 index 00000000..3ce24d86 --- /dev/null +++ b/internal/natpmp/natpmp_test.go @@ -0,0 +1,20 @@ +package natpmp + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_New(t *testing.T) { + t.Parallel() + + expectedClient := &Client{ + serverPort: 5351, + initialRetry: 250 * time.Millisecond, + maxRetries: 9, + } + client := New() + assert.Equal(t, expectedClient, client) +} diff --git a/internal/natpmp/portmapping.go b/internal/natpmp/portmapping.go new file mode 100644 index 00000000..098adf3b --- /dev/null +++ b/internal/natpmp/portmapping.go @@ -0,0 +1,60 @@ +package natpmp + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "net/netip" + "time" +) + +var ( + ErrNetworkProtocolUnknown = errors.New("network protocol is unknown") + ErrLifetimeTooLong = errors.New("lifetime is too long") +) + +// Add or delete a port mapping. To delete a mapping, set both the +// requestedExternalPort and lifetime to 0. +// See https://www.ietf.org/rfc/rfc6886.html#section-3.3 +func (c *Client) AddPortMapping(ctx context.Context, gateway netip.Addr, + protocol string, internalPort, requestedExternalPort uint16, + lifetime time.Duration) (durationSinceStartOfEpoch time.Duration, + assignedInternalPort, assignedExternalPort uint16, assignedLifetime time.Duration, + err error) { + lifetimeSecondsFloat := lifetime.Seconds() + const maxLifetimeSeconds = uint64(^uint32(0)) + if uint64(lifetimeSecondsFloat) > maxLifetimeSeconds { + return 0, 0, 0, 0, fmt.Errorf("%w: %d seconds must at most %d seconds", + ErrLifetimeTooLong, uint64(lifetimeSecondsFloat), maxLifetimeSeconds) + } + const messageSize = 12 + message := make([]byte, messageSize) + message[0] = 0 // Version 0 + switch protocol { + case "udp": + message[1] = 1 // operationCode 1 + case "tcp": + message[1] = 2 // operationCode 2 + default: + return 0, 0, 0, 0, fmt.Errorf("%w: %s", ErrNetworkProtocolUnknown, protocol) + } + // [2:3] are reserved. + binary.BigEndian.PutUint16(message[4:6], internalPort) + binary.BigEndian.PutUint16(message[6:8], requestedExternalPort) + binary.BigEndian.PutUint32(message[8:12], uint32(lifetimeSecondsFloat)) + + const responseSize = 16 + response, err := c.rpc(ctx, gateway, message, responseSize) + if err != nil { + return 0, 0, 0, 0, fmt.Errorf("executing remote procedure call: %w", err) + } + + secondsSinceStartOfEpoch := binary.BigEndian.Uint32(response[4:8]) + durationSinceStartOfEpoch = time.Duration(secondsSinceStartOfEpoch) * time.Second + assignedInternalPort = binary.BigEndian.Uint16(response[8:10]) + assignedExternalPort = binary.BigEndian.Uint16(response[10:12]) + lifetimeInSeconds := binary.BigEndian.Uint32(response[12:16]) + assignedLifetime = time.Duration(lifetimeInSeconds) * time.Second + return durationSinceStartOfEpoch, assignedInternalPort, assignedExternalPort, assignedLifetime, nil +} diff --git a/internal/natpmp/portmapping_test.go b/internal/natpmp/portmapping_test.go new file mode 100644 index 00000000..99721fb2 --- /dev/null +++ b/internal/natpmp/portmapping_test.go @@ -0,0 +1,149 @@ +package natpmp + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Client_AddPortMapping(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ctx context.Context + gateway netip.Addr + protocol string + internalPort uint16 + requestedExternalPort uint16 + lifetime time.Duration + initialRetry time.Duration + exchanges []udpExchange + durationSinceStartOfEpoch time.Duration + assignedInternalPort uint16 + assignedExternalPort uint16 + assignedLifetime time.Duration + err error + errMessage string + }{ + "lifetime_too_long": { + lifetime: time.Duration(uint64(^uint32(0))+1) * time.Second, + err: ErrLifetimeTooLong, + errMessage: "lifetime is too long: 4294967296 seconds must at most 4294967295 seconds", + }, + "protocol_unknown": { + lifetime: time.Second, + protocol: "xyz", + err: ErrNetworkProtocolUnknown, + errMessage: "network protocol is unknown: xyz", + }, + "rpc_error": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + protocol: "udp", + internalPort: 123, + requestedExternalPort: 456, + lifetime: 1200 * time.Second, + initialRetry: time.Millisecond, + exchanges: []udpExchange{{close: true}}, + err: ErrConnectionTimeout, + errMessage: "executing remote procedure call: connection timeout: after 1ms", + }, + "add_udp": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + protocol: "udp", + internalPort: 123, + requestedExternalPort: 456, + lifetime: 1200 * time.Second, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x1, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0x0, 0x81, 0x0, 0x0, 0x0, 0x13, 0xfe, 0xff, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + }}, + durationSinceStartOfEpoch: 0x13feff * time.Second, + assignedInternalPort: 0x7b, + assignedExternalPort: 0x1c8, + assignedLifetime: 0x4b0 * time.Second, + }, + "add_tcp": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + protocol: "tcp", + internalPort: 123, + requestedExternalPort: 456, + lifetime: 1200 * time.Second, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x14, 0x3, 0x21, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + }}, + durationSinceStartOfEpoch: 0x140321 * time.Second, + assignedInternalPort: 0x7b, + assignedExternalPort: 0x1c8, + assignedLifetime: 0x4b0 * time.Second, + }, + "remove_udp": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + protocol: "udp", + internalPort: 123, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x1, 0x0, 0x0, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + response: []byte{0x0, 0x81, 0x0, 0x0, 0x0, 0x14, 0x3, 0xd5, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }}, + durationSinceStartOfEpoch: 0x1403d5 * time.Second, + assignedInternalPort: 0x7b, + }, + "remove_tcp": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + protocol: "tcp", + internalPort: 123, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + response: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }}, + durationSinceStartOfEpoch: 0x140496 * time.Second, + assignedInternalPort: 0x7b, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + remoteAddress := launchUDPServer(t, testCase.exchanges) + + client := Client{ + serverPort: uint16(remoteAddress.Port), + initialRetry: testCase.initialRetry, + maxRetries: 1, + } + + durationSinceStartOfEpoch, assignedInternalPort, + assignedExternalPort, assignedLifetime, err := + client.AddPortMapping(testCase.ctx, testCase.gateway, + testCase.protocol, testCase.internalPort, + testCase.requestedExternalPort, testCase.lifetime) + + assert.Equal(t, testCase.durationSinceStartOfEpoch, durationSinceStartOfEpoch) + assert.Equal(t, testCase.assignedInternalPort, assignedInternalPort) + assert.Equal(t, testCase.assignedExternalPort, assignedExternalPort) + assert.Equal(t, testCase.assignedLifetime, assignedLifetime) + if testCase.errMessage != "" { + if testCase.err != nil { + assert.ErrorIs(t, err, testCase.err) + } + assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/natpmp/rpc.go b/internal/natpmp/rpc.go new file mode 100644 index 00000000..f1b0be31 --- /dev/null +++ b/internal/natpmp/rpc.go @@ -0,0 +1,123 @@ +package natpmp + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "time" +) + +var ( + ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified") + ErrConnectionTimeout = errors.New("connection timeout") +) + +func (c *Client) rpc(ctx context.Context, gateway netip.Addr, + request []byte, responseSize uint) ( + response []byte, err error) { + if gateway.IsUnspecified() || !gateway.IsValid() { + return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified) + } + + err = checkRequest(request) + if err != nil { + return nil, fmt.Errorf("checking request: %w", err) + } + + gatewayAddress := &net.UDPAddr{ + IP: gateway.AsSlice(), + Port: int(c.serverPort), + } + + connection, err := net.DialUDP("udp", nil, gatewayAddress) + if err != nil { + return nil, fmt.Errorf("dialing udp: %w", err) + } + + ctx, cancel := context.WithCancel(ctx) + endGoroutineDone := make(chan struct{}) + defer func() { + cancel() + <-endGoroutineDone + }() + go func() { + defer close(endGoroutineDone) + // Context is canceled either by the parent context or + // when this function returns. + <-ctx.Done() + closeErr := connection.Close() + if closeErr == nil { + return + } + if err == nil { + err = fmt.Errorf("closing connection: %w", closeErr) + return + } + err = fmt.Errorf("%w; closing connection: %w", err, closeErr) + }() + + const maxResponseSize = 16 + response = make([]byte, maxResponseSize) + + // Retry duration doubles on every network error + // Note it does not double if the source IP mismatches the gateway IP. + retryDuration := c.initialRetry + + var totalRetryDuration time.Duration + + var retryCount uint + for retryCount = 0; retryCount < c.maxRetries; retryCount++ { + deadline := time.Now().Add(retryDuration) + err = connection.SetDeadline(deadline) + if err != nil { + return nil, fmt.Errorf("setting connection deadline: %w", err) + } + + _, err = connection.Write(request) + if err != nil { + return nil, fmt.Errorf("writing to connection: %w", err) + } + + bytesRead, receivedRemoteAddress, err := connection.ReadFromUDP(response) + if err != nil { + if ctx.Err() != nil { + return nil, fmt.Errorf("reading from udp connection: %w", ctx.Err()) + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + totalRetryDuration += retryDuration + retryDuration *= 2 + continue + } + return nil, fmt.Errorf("reading from udp connection: %w", err) + } + + if !receivedRemoteAddress.IP.Equal(gatewayAddress.IP) { + // Upon receiving a response packet, the client MUST check the source IP + // address, and silently discard the packet if the address is not the + // address of the gateway to which the request was sent. + continue + } + + response = response[:bytesRead] + break + } + + if retryCount == c.maxRetries { + return nil, fmt.Errorf("%w: after %s", + ErrConnectionTimeout, totalRetryDuration) + } + + // Opcodes between 0 and 127 are client requests. Opcodes from 128 to + // 255 are corresponding server responses. + const operationCodeMask = 128 + expectedOperationCode := request[1] | operationCodeMask + err = checkResponse(response, expectedOperationCode, responseSize) + if err != nil { + return nil, fmt.Errorf("checking response: %w", err) + } + + return response, nil +} diff --git a/internal/natpmp/rpc_test.go b/internal/natpmp/rpc_test.go new file mode 100644 index 00000000..5f56e8f7 --- /dev/null +++ b/internal/natpmp/rpc_test.go @@ -0,0 +1,166 @@ +package natpmp + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Client_rpc(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ctx context.Context + gateway netip.Addr + request []byte + responseSize uint + initialRetry time.Duration + exchanges []udpExchange + expectedResponse []byte + err error + errMessage string + }{ + "gateway_ip_unspecified": { + gateway: netip.IPv6Unspecified(), + request: []byte{0, 0}, + err: ErrGatewayIPUnspecified, + errMessage: "gateway IP is unspecified", + }, + "request_too_small": { + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0}, + initialRetry: time.Second, + err: ErrRequestSizeTooSmall, + errMessage: `checking request: message size is too small: ` + + `need at least 2 bytes and got 1 byte\(s\)`, + }, + "write_error": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0, 0}, + errMessage: `writing to connection: write udp ` + + `127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: ` + + `i/o timeout`, + }, + "call_error": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0, 1}, + initialRetry: time.Millisecond, + exchanges: []udpExchange{ + {request: []byte{0, 1}, close: true}, + }, + err: ErrConnectionTimeout, + errMessage: "connection timeout: after 1ms", + }, + "response_too_small": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0, 0}, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0, 0}, + response: []byte{1}, + }}, + err: ErrResponseSizeTooSmall, + errMessage: `checking response: response size is too small: ` + + `need at least 4 bytes and got 1 byte\(s\)`, + }, + "unexpected_response_size": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + responseSize: 5, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0, 1, 2, 3}, // size 4 + }}, + err: ErrResponseSizeUnexpected, + errMessage: `checking response: response size is unexpected: ` + + `expected 5 bytes and got 4 byte\(s\)`, + }, + "unknown_protocol_version": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + responseSize: 16, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0x1, 0x82, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }}, + err: ErrProtocolVersionUnknown, + errMessage: "checking response: protocol version is unknown: 1", + }, + "unexpected_operation_code": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + responseSize: 16, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0x0, 0x88, 0x0, 0x0, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }}, + err: ErrOperationCodeUnexpected, + errMessage: "checking response: operation code is unexpected: expected 0x82 and got 0x88", + }, + "failure_result_code": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + responseSize: 16, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0x0, 0x82, 0x0, 0x11, 0x0, 0x14, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }}, + err: ErrResultCodeUnknown, + errMessage: "checking response: result code: result code is unknown: 17", + }, + "success": { + ctx: context.Background(), + gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + responseSize: 16, + initialRetry: time.Second, + exchanges: []udpExchange{{ + request: []byte{0x0, 0x2, 0x0, 0x0, 0x0, 0x7b, 0x1, 0xc8, 0x0, 0x0, 0x4, 0xb0}, + response: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x0, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }}, + expectedResponse: []byte{0x0, 0x82, 0x0, 0x0, 0x0, 0x0, 0x4, 0x96, 0x0, 0x7b, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + remoteAddress := launchUDPServer(t, testCase.exchanges) + + client := Client{ + serverPort: uint16(remoteAddress.Port), + initialRetry: testCase.initialRetry, + maxRetries: 1, + } + + response, err := client.rpc(testCase.ctx, testCase.gateway, + testCase.request, testCase.responseSize) + + if testCase.errMessage != "" { + if testCase.err != nil { + assert.ErrorIs(t, err, testCase.err) + } + assert.Regexp(t, "^"+testCase.errMessage+"$", err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, testCase.expectedResponse, response) + }) + } +} diff --git a/internal/portforward/run.go b/internal/portforward/run.go index 1ff9ef15..bf0ab740 100644 --- a/internal/portforward/run.go +++ b/internal/portforward/run.go @@ -34,8 +34,8 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { portCh <- port // Infinite loop - err = startData.PortForwarder.KeepPortForward(ctx, - startData.Gateway, startData.ServerName) + err = startData.PortForwarder.KeepPortForward(ctx, port, + startData.Gateway, startData.ServerName, l.logger) errorCh <- err }(pfCtx, startData) diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index a6cf844a..93f49b3b 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -91,8 +91,8 @@ var ( ErrPortForwardedExpired = errors.New("port forwarded data expired") ) -func (p *Provider) KeepPortForward(ctx context.Context, - gateway netip.Addr, serverName string) (err error) { +func (p *Provider) KeepPortForward(ctx context.Context, _ uint16, + gateway netip.Addr, serverName string, _ utils.Logger) (err error) { privateIPClient, err := newHTTPClient(serverName) if err != nil { return fmt.Errorf("creating custom HTTP client: %w", err) diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go new file mode 100644 index 00000000..7eab60cc --- /dev/null +++ b/internal/provider/protonvpn/portforward.go @@ -0,0 +1,103 @@ +package protonvpn + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/netip" + "time" + + "github.com/qdm12/gluetun/internal/natpmp" + "github.com/qdm12/gluetun/internal/provider/utils" +) + +var ( + ErrGatewayIPNotValid = errors.New("gateway IP address is not valid") +) + +// PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. +func (p *Provider) PortForward(ctx context.Context, _ *http.Client, + logger utils.Logger, gateway netip.Addr, _ string) ( + port uint16, err error) { + if !gateway.IsValid() { + return 0, fmt.Errorf("%w", ErrGatewayIPNotValid) + } + + client := natpmp.New() + _, externalIPv4Address, err := client.ExternalAddress(ctx, + gateway) + if err != nil { + return 0, fmt.Errorf("getting external IPv4 address: %w", err) + } + + logger.Info("gateway external IPv4 address is " + externalIPv4Address.String()) + networkProtocols := []string{"udp", "tcp"} + const internalPort, externalPort = 0, 0 + const lifetime = 60 * time.Second + for _, networkProtocol := range networkProtocols { + _, assignedInternalPort, assignedExternalPort, assignedLiftetime, err := + client.AddPortMapping(ctx, gateway, networkProtocol, + internalPort, externalPort, lifetime) + if err != nil { + return 0, fmt.Errorf("adding port mapping: %w", err) + } + + if assignedLiftetime != lifetime { + logger.Warn(fmt.Sprintf("assigned lifetime %s differs"+ + " from requested lifetime %s", + assignedLiftetime, lifetime)) + } + + if assignedInternalPort != assignedExternalPort { + logger.Warn(fmt.Sprintf("internal port assigned %d differs"+ + " from external port assigned %d", + assignedInternalPort, assignedExternalPort)) + } + + port = assignedExternalPort + } + + return port, nil +} + +func (p *Provider) KeepPortForward(ctx context.Context, port uint16, + gateway netip.Addr, _ string, logger utils.Logger) (err error) { + client := natpmp.New() + const refreshTimeout = 45 * time.Second + timer := time.NewTimer(refreshTimeout) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } + + networkProtocols := []string{"udp", "tcp"} + const internalPort = 0 + const lifetime = 60 * time.Second + + for _, networkProtocol := range networkProtocols { + _, assignedInternalPort, assignedExternalPort, assignedLiftetime, err := + client.AddPortMapping(ctx, gateway, networkProtocol, + internalPort, port, lifetime) + if err != nil { + return fmt.Errorf("adding port mapping: %w", err) + } + + if assignedLiftetime != lifetime { + logger.Warn(fmt.Sprintf("assigned lifetime %s differs"+ + " from requested lifetime %s", + assignedLiftetime, lifetime)) + } + + if assignedInternalPort != assignedExternalPort { + logger.Warn(fmt.Sprintf("internal port assigned %d differs"+ + " from external port assigned %d", + assignedInternalPort, assignedExternalPort)) + } + } + + timer.Reset(refreshTimeout) + } +} diff --git a/internal/provider/protonvpn/provider.go b/internal/provider/protonvpn/provider.go index f5b9e706..b0e5d711 100644 --- a/internal/provider/protonvpn/provider.go +++ b/internal/provider/protonvpn/provider.go @@ -7,23 +7,20 @@ import ( "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/provider/common" "github.com/qdm12/gluetun/internal/provider/protonvpn/updater" - "github.com/qdm12/gluetun/internal/provider/utils" ) type Provider struct { storage common.Storage randSource rand.Source - utils.NoPortForwarder common.Fetcher } func New(storage common.Storage, randSource rand.Source, client *http.Client, updaterWarner common.Warner) *Provider { return &Provider{ - storage: storage, - randSource: randSource, - NoPortForwarder: utils.NewNoPortForwarding(providers.Protonvpn), - Fetcher: updater.New(client, updaterWarner), + storage: storage, + randSource: randSource, + Fetcher: updater.New(client, updaterWarner), } } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 3fe70cce..94584087 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -24,6 +24,6 @@ type PortForwarder interface { PortForward(ctx context.Context, client *http.Client, logger utils.Logger, gateway netip.Addr, serverName string) ( port uint16, err error) - KeepPortForward(ctx context.Context, gateway netip.Addr, - serverName string) (err error) + KeepPortForward(ctx context.Context, port uint16, gateway netip.Addr, + serverName string, _ utils.Logger) (err error) } diff --git a/internal/provider/utils/noportforward.go b/internal/provider/utils/noportforward.go index d70bc21a..2df92e5f 100644 --- a/internal/provider/utils/noportforward.go +++ b/internal/provider/utils/noportforward.go @@ -12,8 +12,8 @@ type NoPortForwarder interface { PortForward(ctx context.Context, client *http.Client, logger Logger, gateway netip.Addr, serverName string) ( port uint16, err error) - KeepPortForward(ctx context.Context, gateway netip.Addr, - serverName string) (err error) + KeepPortForward(ctx context.Context, port uint16, gateway netip.Addr, + serverName string, logger Logger) (err error) } type NoPortForwarding struct { @@ -33,6 +33,7 @@ func (n *NoPortForwarding) PortForward(context.Context, *http.Client, return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) } -func (n *NoPortForwarding) KeepPortForward(context.Context, netip.Addr, string) (err error) { +func (n *NoPortForwarding) KeepPortForward(context.Context, uint16, netip.Addr, + string, Logger) (err error) { return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) }