diff --git a/internal/configuration/settings/serverselection.go b/internal/configuration/settings/serverselection.go index ed254f2f..37efc08f 100644 --- a/internal/configuration/settings/serverselection.go +++ b/internal/configuration/settings/serverselection.go @@ -191,11 +191,19 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter return fmt.Errorf("%w: %w", ErrHostnameNotValid, err) } - if vpnServiceProvider == providers.Custom && len(settings.Names) == 1 { - // Allow a single name to be specified for the custom provider in case - // the user wants to use VPN server side port forwarding with PIA - // which requires a server name for TLS verification. - filterChoices.Names = settings.Names + if vpnServiceProvider == providers.Custom { + switch len(settings.Names) { + case 0: + case 1: + // Allow a single name to be specified for the custom provider in case + // the user wants to use VPN server side port forwarding with PIA + // which requires a server name for TLS verification. + filterChoices.Names = settings.Names + default: + return fmt.Errorf("%w: %d names specified instead of "+ + "0 or 1 for the custom provider", + ErrNameNotValid, len(settings.Names)) + } } err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names) if err != nil { diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 14b03e58..6de34b59 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -39,7 +39,7 @@ func (p *Provider) PortForward(ctx context.Context, } serverName := objects.ServerName - + apiIP := buildAPIIPAddress(objects.Gateway) logger := objects.Logger if !objects.CanPortForward { @@ -70,7 +70,7 @@ func (p *Provider) PortForward(ctx context.Context, if !dataFound || expired { client := objects.Client - data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway, + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, apiIP, p.portForwardPath, objects.Username, objects.Password) if err != nil { return nil, fmt.Errorf("refreshing port forward data: %w", err) @@ -80,7 +80,7 @@ func (p *Provider) PortForward(ctx context.Context, logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration)) // First time binding - if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil { + if err := bindPort(ctx, privateIPClient, apiIP, data); err != nil { return nil, fmt.Errorf("binding port: %w", err) } @@ -100,6 +100,8 @@ func (p *Provider) KeepPortForward(ctx context.Context, panic("gateway is not set") } + apiIP := buildAPIIPAddress(objects.Gateway) + privateIPClient, err := newHTTPClient(objects.ServerName) if err != nil { return fmt.Errorf("creating custom HTTP client: %w", err) @@ -127,7 +129,7 @@ func (p *Provider) KeepPortForward(ctx context.Context, } return ctx.Err() case <-keepAliveTimer.C: - err = bindPort(ctx, privateIPClient, objects.Gateway, data) + err = bindPort(ctx, privateIPClient, apiIP, data) if err != nil { return fmt.Errorf("binding port: %w", err) } @@ -139,14 +141,25 @@ func (p *Provider) KeepPortForward(ctx context.Context, } } +func buildAPIIPAddress(gateway netip.Addr) (api netip.Addr) { + if gateway.Is6() { + panic("IPv6 gateway not supported") + } + + gatewayBytes := gateway.As4() + gatewayBytes[2] = 128 + gatewayBytes[3] = 1 + return netip.AddrFrom4(gatewayBytes) +} + func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client, - gateway netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) { + apiIP netip.Addr, portForwardPath, username, password string) (data piaPortForwardData, err error) { data.Token, err = fetchToken(ctx, client, username, password) if err != nil { return data, fmt.Errorf("fetching token: %w", err) } - data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token) + data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, apiIP, data.Token) if err != nil { return data, fmt.Errorf("fetching port forwarding data: %w", err) } @@ -286,7 +299,7 @@ func fetchToken(ctx context.Context, client *http.Client, return result.Token, nil } -func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) ( +func fetchPortForwardData(ctx context.Context, client *http.Client, apiIP netip.Addr, token string) ( port uint16, signature string, expiration time.Time, err error) { errSubstitutions := map[string]string{url.QueryEscape(token): ""} @@ -294,7 +307,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti queryParams.Add("token", token) url := url.URL{ Scheme: "https", - Host: net.JoinHostPort(gateway.String(), "19999"), + Host: net.JoinHostPort(apiIP.String(), "19999"), Path: "/getSignature", RawQuery: queryParams.Encode(), } @@ -340,7 +353,7 @@ var ( ErrBadResponse = errors.New("bad response received") ) -func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) { +func bindPort(ctx context.Context, client *http.Client, apiIPAddress netip.Addr, data piaPortForwardData) (err error) { payload, err := packPayload(data.Port, data.Token, data.Expiration) if err != nil { return fmt.Errorf("serializing payload: %w", err) @@ -351,7 +364,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data queryParams.Add("signature", data.Signature) bindPortURL := url.URL{ Scheme: "https", - Host: net.JoinHostPort(gateway.String(), "19999"), + Host: net.JoinHostPort(apiIPAddress.String(), "19999"), Path: "/bindPort", RawQuery: queryParams.Encode(), }