175 lines
4.7 KiB
Go
175 lines
4.7 KiB
Go
package natpmp
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
var (
	// ErrGatewayIPUnspecified is returned when the gateway address given for
	// an RPC is the zero netip.Addr or an unspecified address (0.0.0.0 / ::).
	// Match it with errors.Is.
	ErrGatewayIPUnspecified = errors.New("gateway IP is unspecified")

	// ErrConnectionTimeout is returned when every attempt to obtain a valid
	// response from the gateway failed. The wrapping error lists the
	// individual attempt failures. Match it with errors.Is.
	ErrConnectionTimeout = errors.New("connection timeout")
)
|
|
|
|
func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
|
|
request []byte, responseSize uint) (
|
|
response []byte, err error) {
|
|
if gateway.IsUnspecified() || !gateway.IsValid() {
|
|
return nil, fmt.Errorf("%w", ErrGatewayIPUnspecified)
|
|
}
|
|
|
|
err = checkRequest(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("checking request: %w", err)
|
|
}
|
|
|
|
gatewayAddress := &net.UDPAddr{
|
|
IP: gateway.AsSlice(),
|
|
Port: int(c.serverPort),
|
|
}
|
|
|
|
connection, err := net.DialUDP("udp", nil, gatewayAddress)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dialing udp: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
endGoroutineDone := make(chan struct{})
|
|
defer func() {
|
|
cancel()
|
|
<-endGoroutineDone
|
|
}()
|
|
go func() {
|
|
defer close(endGoroutineDone)
|
|
// Context is canceled either by the parent context or
|
|
// when this function returns.
|
|
<-ctx.Done()
|
|
closeErr := connection.Close()
|
|
if closeErr == nil {
|
|
return
|
|
}
|
|
if err == nil {
|
|
err = fmt.Errorf("closing connection: %w", closeErr)
|
|
return
|
|
}
|
|
err = fmt.Errorf("%w; closing connection: %w", err, closeErr)
|
|
}()
|
|
|
|
const maxResponseSize = 16
|
|
response = make([]byte, maxResponseSize)
|
|
|
|
// Connection duration doubles on every network error
|
|
// Note it does not double if the source IP mismatches the gateway IP.
|
|
connectionDuration := c.initialConnectionDuration
|
|
|
|
var retryCount uint
|
|
var failedAttempts []string
|
|
for retryCount = 0; retryCount < c.maxRetries; retryCount++ { //nolint:intrange
|
|
deadline := time.Now().Add(connectionDuration)
|
|
err = connection.SetDeadline(deadline)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("setting connection deadline: %w", err)
|
|
}
|
|
|
|
_, err = connection.Write(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("writing to connection: %w", err)
|
|
}
|
|
|
|
bytesRead, receivedRemoteAddress, err := connection.ReadFromUDP(response)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil, fmt.Errorf("reading from udp connection: %w", ctx.Err())
|
|
}
|
|
var netErr net.Error
|
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
|
connectionDuration *= 2
|
|
failedAttempts = append(failedAttempts, netErr.Error())
|
|
continue
|
|
}
|
|
return nil, fmt.Errorf("reading from udp connection: %w", err)
|
|
}
|
|
|
|
if !receivedRemoteAddress.IP.Equal(gatewayAddress.IP) {
|
|
// Upon receiving a response packet, the client MUST check the source IP
|
|
// address, and silently discard the packet if the address is not the
|
|
// address of the gateway to which the request was sent.
|
|
failedAttempts = append(failedAttempts,
|
|
fmt.Sprintf("received response from %s instead of gateway IP %s",
|
|
receivedRemoteAddress.IP, gatewayAddress.IP))
|
|
continue
|
|
}
|
|
|
|
response = response[:bytesRead]
|
|
break
|
|
}
|
|
|
|
if retryCount == c.maxRetries {
|
|
return nil, fmt.Errorf("%w: failed attempts: %s",
|
|
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
|
|
}
|
|
|
|
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
|
|
// 255 are corresponding server responses.
|
|
const operationCodeMask = 128
|
|
expectedOperationCode := request[1] | operationCodeMask
|
|
err = checkResponse(response, expectedOperationCode, responseSize)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("checking response: %w", err)
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// dedupFailedAttempts collapses repeated failure messages into a single
// entry each. Every entry is annotated with the 1-based attempt numbers at
// which that message occurred. Entries are ordered by the first attempt at
// which each distinct message appeared and are joined with "; ". For example,
// ["a", "b", "a"] becomes "a (tries 1, 3); b (try 2)". An empty input yields
// "".
func dedupFailedAttempts(failedAttempts []string) (errorMessage string) {
	type data struct {
		message string
		indices []int
	}
	messageToData := make(map[string]data, len(failedAttempts))
	for i, message := range failedAttempts {
		metadata, ok := messageToData[message]
		if !ok {
			metadata.message = message
		}
		// i only increases, so indices is built in ascending order and
		// needs no sorting.
		metadata.indices = append(metadata.indices, i)
		messageToData[message] = metadata
	}

	// Map iteration order is random, so sort by first index. Each distinct
	// message has a unique first index, which makes the order deterministic.
	dataSlice := make([]data, 0, len(messageToData))
	for _, metadata := range messageToData {
		dataSlice = append(dataSlice, metadata)
	}
	sort.Slice(dataSlice, func(i, j int) bool {
		return dataSlice[i].indices[0] < dataSlice[j].indices[0]
	})

	dedupedFailedAttempts := make([]string, 0, len(dataSlice))
	for _, data := range dataSlice {
		newMessage := fmt.Sprintf("%s (%s)", data.message,
			indicesToTryString(data.indices))
		dedupedFailedAttempts = append(dedupedFailedAttempts, newMessage)
	}
	return strings.Join(dedupedFailedAttempts, "; ")
}

// indicesToTryString converts 0-based attempt indices into a human-readable
// 1-based label. A single index gives "try N". Several give
// "tries N1, N2, ...". It expects a non-empty slice. With an empty slice the
// result is "tries ".
func indicesToTryString(indices []int) string {
	if len(indices) == 1 {
		return fmt.Sprintf("try %d", indices[0]+1)
	}
	tries := make([]string, len(indices))
	for i, index := range indices {
		tries[i] = fmt.Sprintf("%d", index+1)
	}
	return fmt.Sprintf("tries %s", strings.Join(tries, ", "))
}
|