fix(pia): support port forwarding using Wireguard (#2420)

- Build API IP address using the first 2 bytes of the gateway IP and adding `128.1` to it
- API IP address is valid for both OpenVPN and Wireguard
- Fix #2320
This commit is contained in:
Quentin McGaw
2024-08-19 03:19:16 +02:00
parent c33158c13c
commit 7064a44403
2 changed files with 36 additions and 15 deletions

View File

@@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err) return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
} }
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 { if vpnServiceProvider == providers.Custom {
// Allow a single name to be specified for the custom provider in case switch len(settings.Names) {
// the user wants to use VPN server side port forwarding with PIA case 0:
// which requires a server name for TLS verification. case 1:
filterChoices.Names = settings.Names // Allow a single name to be specified for the custom provider in case
// the user wants to use VPN server side port forwarding with PIA
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
default:
return fmt.Errorf("%w: %d names specified instead of "+
"0 or 1 for the custom provider",
ErrNameNotValid, len(settings.Names))
}
} }
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names) err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
if err != nil { if err != nil {

View File

@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
} }
serverName := objects.ServerName serverName := objects.ServerName
apiIP := buildAPIIPAddress(objects.Gateway)
logger := objects.Logger logger := objects.Logger
if !objects.CanPortForward { if !objects.CanPortForward {
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
if !dataFound || expired { if !dataFound || expired {
client := objects.Client client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway, data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
p.portForwardPath, objects.Username, objects.Password) p.portForwardPath, objects.Username, objects.Password)
if err != nil { if err != nil {
return nil, fmt.Errorf("refreshing port forward data: %w", err) return nil, fmt.Errorf("refreshing port forward data: %w", err)
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration)) logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
// First time binding // First time binding
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil { if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
return nil, fmt.Errorf("binding port: %w", err) return nil, fmt.Errorf("binding port: %w", err)
} }
@@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context,
panic("gateway is not set") panic("gateway is not set")
} }
apiIP := buildAPIIPAddress(objects.Gateway)
privateIPClient, err := newHTTPClient(objects.ServerName) privateIPClient, err := newHTTPClient(objects.ServerName)
if err != nil { if err != nil {
return fmt.Errorf("creating custom HTTP client: %w", err) return fmt.Errorf("creating custom HTTP client: %w", err)
@@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
} }
return ctx.Err() return ctx.Err()
case <-keepAliveTimer.C: case <-keepAliveTimer.C:
err = bindPort(ctx, privateIPClient, objects.Gateway, data) err = bindPort(ctx, privateIPClient, apiIP, data)
if err != nil { if err != nil {
return fmt.Errorf("binding port: %w", err) return fmt.Errorf("binding port: %w", err)
} }
@@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context,
} }
} }
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
if gateway.Is6() {
panic("IPv6 gateway not supported")
}
gatewayBytes := gateway.As4()
gatewayBytes[2] = 128
gatewayBytes[3] = 1
return netip.AddrFrom4(gatewayBytes)
}
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client, func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) { apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, username, password) data.Token, err = fetchToken(ctx, client, username, password)
if err != nil { if err != nil {
return data, fmt.Errorf("fetching token: %w", err) return data, fmt.Errorf("fetching token: %w", err)
} }
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token) data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
if err != nil { if err != nil {
return data, fmt.Errorf("fetching port forwarding data: %w", err) return data, fmt.Errorf("fetching port forwarding data: %w", err)
} }
@@ -286,7 +299,7 @@ func fetchToken(ctx context.Context, client *http.Client,
return result.Token, nil return result.Token, nil
} }
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) ( func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
port uint16, signature string, expiration time.Time, err error) { port uint16, signature string, expiration time.Time, err error) {
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"} errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
@@ -294,7 +307,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti
queryParams.Add("token", token) queryParams.Add("token", token)
url := url.URL{ url := url.URL{
Scheme: "https", Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"), Host: net.JoinHostPort(apiIP.String(), "19999"),
Path: "/getSignature", Path: "/getSignature",
RawQuery: queryParams.Encode(), RawQuery: queryParams.Encode(),
} }
@@ -340,7 +353,7 @@ var (
ErrBadResponse = errors.New("bad response received") ErrBadResponse = errors.New("bad response received")
) )
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) { func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
payload, err := packPayload(data.Port, data.Token, data.Expiration) payload, err := packPayload(data.Port, data.Token, data.Expiration)
if err != nil { if err != nil {
return fmt.Errorf("serializing payload: %w", err) return fmt.Errorf("serializing payload: %w", err)
@@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
queryParams.Add("signature", data.Signature) queryParams.Add("signature", data.Signature)
bindPortURL := url.URL{ bindPortURL := url.URL{
Scheme: "https", Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"), Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
Path: "/bindPort", Path: "/bindPort",
RawQuery: queryParams.Encode(), RawQuery: queryParams.Encode(),
} }