fix(pia): support port forwarding using Wireguard (#2420)
- Build API IP address using the first 2 bytes of the gateway IP and adding `128.1` to it - API IP address is valid for both OpenVPN and Wireguard - Fix #2320
This commit is contained in:
@@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
|
|||||||
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 {
|
if vpnServiceProvider == providers.Custom {
|
||||||
// Allow a single name to be specified for the custom provider in case
|
switch len(settings.Names) {
|
||||||
// the user wants to use VPN server side port forwarding with PIA
|
case 0:
|
||||||
// which requires a server name for TLS verification.
|
case 1:
|
||||||
filterChoices.Names = settings.Names
|
// Allow a single name to be specified for the custom provider in case
|
||||||
|
// the user wants to use VPN server side port forwarding with PIA
|
||||||
|
// which requires a server name for TLS verification.
|
||||||
|
filterChoices.Names = settings.Names
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %d names specified instead of "+
|
||||||
|
"0 or 1 for the custom provider",
|
||||||
|
ErrNameNotValid, len(settings.Names))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
serverName := objects.ServerName
|
serverName := objects.ServerName
|
||||||
|
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||||
logger := objects.Logger
|
logger := objects.Logger
|
||||||
|
|
||||||
if !objects.CanPortForward {
|
if !objects.CanPortForward {
|
||||||
@@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
|||||||
|
|
||||||
if !dataFound || expired {
|
if !dataFound || expired {
|
||||||
client := objects.Client
|
client := objects.Client
|
||||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
|
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP,
|
||||||
p.portForwardPath, objects.Username, objects.Password)
|
p.portForwardPath, objects.Username, objects.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
||||||
@@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
|||||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||||
|
|
||||||
// First time binding
|
// First time binding
|
||||||
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
|
if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil {
|
||||||
return nil, fmt.Errorf("binding port: %w", err)
|
return nil, fmt.Errorf("binding port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
|||||||
panic("gateway is not set")
|
panic("gateway is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiIP := buildAPIIPAddress(objects.Gateway)
|
||||||
|
|
||||||
privateIPClient, err := newHTTPClient(objects.ServerName)
|
privateIPClient, err := newHTTPClient(objects.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating custom HTTP client: %w", err)
|
return fmt.Errorf("creating custom HTTP client: %w", err)
|
||||||
@@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
|||||||
}
|
}
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-keepAliveTimer.C:
|
case <-keepAliveTimer.C:
|
||||||
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
|
err = bindPort(ctx, privateIPClient, apiIP, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("binding port: %w", err)
|
return fmt.Errorf("binding port: %w", err)
|
||||||
}
|
}
|
||||||
@@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) {
|
||||||
|
if gateway.Is6() {
|
||||||
|
panic("IPv6 gateway not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayBytes := gateway.As4()
|
||||||
|
gatewayBytes[2] = 128
|
||||||
|
gatewayBytes[3] = 1
|
||||||
|
return netip.AddrFrom4(gatewayBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||||
gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) {
|
||||||
data.Token, err = fetchToken(ctx, client, username, password)
|
data.Token, err = fetchToken(ctx, client, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("fetching token: %w", err)
|
return data, fmt.Errorf("fetching token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("fetching port forwarding data: %w", err)
|
return data, fmt.Errorf("fetching port forwarding data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -286,7 +299,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|||||||
return result.Token, nil
|
return result.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
|
func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) (
|
||||||
port uint16, signature string, expiration time.Time, err error) {
|
port uint16, signature string, expiration time.Time, err error) {
|
||||||
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
|
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
|
||||||
|
|
||||||
@@ -294,7 +307,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti
|
|||||||
queryParams.Add("token", token)
|
queryParams.Add("token", token)
|
||||||
url := url.URL{
|
url := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
Host: net.JoinHostPort(apiIP.String(), "19999"),
|
||||||
Path: "/getSignature",
|
Path: "/getSignature",
|
||||||
RawQuery: queryParams.Encode(),
|
RawQuery: queryParams.Encode(),
|
||||||
}
|
}
|
||||||
@@ -340,7 +353,7 @@ var (
|
|||||||
ErrBadResponse = errors.New("bad response received")
|
ErrBadResponse = errors.New("bad response received")
|
||||||
)
|
)
|
||||||
|
|
||||||
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
|
func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) {
|
||||||
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("serializing payload: %w", err)
|
return fmt.Errorf("serializing payload: %w", err)
|
||||||
@@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
|
|||||||
queryParams.Add("signature", data.Signature)
|
queryParams.Add("signature", data.Signature)
|
||||||
bindPortURL := url.URL{
|
bindPortURL := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
Host: net.JoinHostPort(apiIPAddress.String(), "19999"),
|
||||||
Path: "/bindPort",
|
Path: "/bindPort",
|
||||||
RawQuery: queryParams.Encode(),
|
RawQuery: queryParams.Encode(),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user