Fix: PIA port forwarding (#427)
- Update PIA token URL - Change base64 decoding to standard decoding - Add unit tests - Remove environment variable `GODEBUG=x509ignoreCN=0` - Fixes #423 - Fixes #292 - Closes #264 - Closes #293
This commit is contained in:
@@ -147,5 +147,4 @@ RUN apk add -q --progress --no-cache --update openvpn ca-certificates iptables i
|
|||||||
deluser unbound && \
|
deluser unbound && \
|
||||||
mkdir /gluetun
|
mkdir /gluetun
|
||||||
# TODO remove once SAN is added to PIA servers certificates, see https://github.com/pia-foss/manual-connections/issues/10
|
# TODO remove once SAN is added to PIA servers certificates, see https://github.com/pia-foss/manual-connections/issues/10
|
||||||
ENV GODEBUG=x509ignoreCN=0
|
|
||||||
COPY --from=build /tmp/gobuild/entrypoint /entrypoint
|
COPY --from=build /tmp/gobuild/entrypoint /entrypoint
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := newPIAHTTPClient(commonName)
|
privateIPClient, err := newPIAHTTPClient(commonName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pfLogger.Error("aborting because: %s", err)
|
pfLogger.Error("aborting because: %s", err)
|
||||||
return
|
return
|
||||||
@@ -246,7 +246,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
if !dataFound || expired {
|
if !dataFound || expired {
|
||||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||||
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
|
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -258,7 +258,10 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
// First time binding
|
// First time binding
|
||||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||||
return bindPIAPort(ctx, client, gateway, data)
|
if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||||
|
return fmt.Errorf("cannot bind port: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
@@ -294,15 +297,15 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-keepAliveTimer.C:
|
case <-keepAliveTimer.C:
|
||||||
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
|
if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error("cannot bind port: " + err.Error())
|
||||||
}
|
}
|
||||||
keepAliveTimer.Reset(keepAlivePeriod)
|
keepAliveTimer.Reset(keepAlivePeriod)
|
||||||
case <-expiryTimer.C:
|
case <-expiryTimer.C:
|
||||||
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
||||||
oldPort := data.Port
|
oldPort := data.Port
|
||||||
for {
|
for {
|
||||||
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
|
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
continue
|
continue
|
||||||
@@ -322,8 +325,8 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
|||||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
}
|
}
|
||||||
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
|
if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error("cannot bind port: " + err.Error())
|
||||||
}
|
}
|
||||||
if !keepAliveTimer.Stop() {
|
if !keepAliveTimer.Stop() {
|
||||||
<-keepAliveTimer.C
|
<-keepAliveTimer.C
|
||||||
@@ -357,22 +360,14 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
|
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
|
||||||
}
|
}
|
||||||
// certificate.DNSNames = []string{serverName, "10.0.0.1"}
|
|
||||||
rootCAs := x509.NewCertPool()
|
|
||||||
rootCAs.AddCert(certificate)
|
|
||||||
TLSClientConfig := &tls.Config{
|
|
||||||
RootCAs: rootCAs,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
ServerName: serverName,
|
|
||||||
}
|
|
||||||
//nolint:gomnd
|
//nolint:gomnd
|
||||||
transport := http.Transport{
|
transport := &http.Transport{
|
||||||
TLSClientConfig: TLSClientConfig,
|
// Settings taken from http.DefaultTransport
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
KeepAlive: 30 * time.Second,
|
KeepAlive: 30 * time.Second,
|
||||||
DualStack: true,
|
|
||||||
}).DialContext,
|
}).DialContext,
|
||||||
ForceAttemptHTTP2: true,
|
ForceAttemptHTTP2: true,
|
||||||
MaxIdleConns: 100,
|
MaxIdleConns: 100,
|
||||||
@@ -380,18 +375,28 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
|
|||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
rootCAs := x509.NewCertPool()
|
||||||
|
rootCAs.AddCert(certificate)
|
||||||
|
transport.TLSClientConfig = &tls.Config{
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: serverName,
|
||||||
|
}
|
||||||
|
|
||||||
const httpTimeout = 30 * time.Second
|
const httpTimeout = 30 * time.Second
|
||||||
client = &http.Client{Transport: &transport, Timeout: httpTimeout}
|
return &http.Client{
|
||||||
return client, nil
|
Transport: transport,
|
||||||
|
Timeout: httpTimeout,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
|
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||||
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||||
data.Token, err = fetchPIAToken(ctx, openFile, client)
|
data.Token, err = fetchPIAToken(ctx, openFile, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("cannot obtain token: %w", err)
|
return data, fmt.Errorf("cannot obtain token: %w", err)
|
||||||
}
|
}
|
||||||
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token)
|
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
|
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -448,13 +453,15 @@ func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
|
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
|
||||||
b, err := base64.RawStdEncoding.DecodeString(payload)
|
b, err := base64.StdEncoding.DecodeString(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err)
|
return 0, "", expiration,
|
||||||
|
fmt.Errorf("cannot decode payload: payload is %q: %w", payload, err)
|
||||||
}
|
}
|
||||||
var payloadData piaPayload
|
var payloadData piaPayload
|
||||||
if err := json.Unmarshal(b, &payloadData); err != nil {
|
if err := json.Unmarshal(b, &payloadData); err != nil {
|
||||||
return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err)
|
return 0, "", expiration,
|
||||||
|
fmt.Errorf("cannot parse payload data: data is %q: %w", string(b), err)
|
||||||
}
|
}
|
||||||
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
|
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
|
||||||
}
|
}
|
||||||
@@ -469,7 +476,7 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("cannot serialize payload data: %w", err)
|
return "", fmt.Errorf("cannot serialize payload data: %w", err)
|
||||||
}
|
}
|
||||||
payload = base64.RawStdEncoding.EncodeToString(b)
|
payload = base64.StdEncoding.EncodeToString(b)
|
||||||
return payload, nil
|
return payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -482,16 +489,18 @@ func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc,
|
|||||||
url := url.URL{
|
url := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
User: url.UserPassword(username, password),
|
User: url.UserPassword(username, password),
|
||||||
Host: "10.0.0.1",
|
Host: "privateinternetaccess.com",
|
||||||
Path: "/authv3/generateToken",
|
Path: "/gtoken/generateToken",
|
||||||
}
|
}
|
||||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", replaceInErr(err, map[string]string{
|
||||||
|
username: "<username>", password: "<password>"})
|
||||||
}
|
}
|
||||||
response, err := client.Do(request)
|
response, err := client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", replaceInErr(err, map[string]string{
|
||||||
|
username: "<username>", password: "<password>"})
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
@@ -547,10 +556,12 @@ func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway n
|
|||||||
}
|
}
|
||||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = replaceInErr(err, map[string]string{token: "<token>"})
|
||||||
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
||||||
}
|
}
|
||||||
response, err := client.Do(request)
|
response, err := client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = replaceInErr(err, map[string]string{token: "<token>"})
|
||||||
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
@@ -590,11 +601,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
|
|||||||
|
|
||||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot bind port: %w", err)
|
return replaceInErr(err, map[string]string{
|
||||||
|
payload: "<payload>",
|
||||||
|
data.Signature: "<signature>",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
response, err := client.Do(request)
|
response, err := client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot bind port: %w", err)
|
return replaceInErr(err, map[string]string{
|
||||||
|
payload: "<payload>",
|
||||||
|
data.Signature: "<signature>",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
@@ -607,7 +624,7 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
if err := decoder.Decode(&responseData); err != nil {
|
if err := decoder.Decode(&responseData); err != nil {
|
||||||
return fmt.Errorf("cannot bind port: %w", err)
|
return err
|
||||||
} else if responseData.Status != "OK" {
|
} else if responseData.Status != "OK" {
|
||||||
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
|
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
|
||||||
}
|
}
|
||||||
@@ -627,3 +644,12 @@ func writePortForwardedToFile(openFile os.OpenFileFunc,
|
|||||||
}
|
}
|
||||||
return file.Close()
|
return file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replaceInErr is used to remove sensitive information from logs.
|
||||||
|
func replaceInErr(err error, substitutions map[string]string) error {
|
||||||
|
s := err.Error()
|
||||||
|
for old, new := range substitutions {
|
||||||
|
s = strings.ReplaceAll(s, old, new)
|
||||||
|
}
|
||||||
|
return errors.New(s)
|
||||||
|
}
|
||||||
|
|||||||
114
internal/provider/piav4_test.go
Normal file
114
internal/provider/piav4_test.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_newPIAHTTPClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const serverName = "testserver"
|
||||||
|
|
||||||
|
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
|
||||||
|
require.NoError(t, err)
|
||||||
|
certificate, err := x509.ParseCertificate(certificateBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
rootCAs := x509.NewCertPool()
|
||||||
|
rootCAs.AddCert(certificate)
|
||||||
|
expectedRootCAsSubjects := rootCAs.Subjects()
|
||||||
|
|
||||||
|
expectedPIATransportTLSConfig := &tls.Config{
|
||||||
|
// Can't directly compare RootCAs because of private fields
|
||||||
|
RootCAs: nil,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: serverName,
|
||||||
|
}
|
||||||
|
|
||||||
|
piaClient, err := newPIAHTTPClient(serverName)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify pia transport TLS config is set
|
||||||
|
piaTransport := piaClient.Transport.(*http.Transport)
|
||||||
|
rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects()
|
||||||
|
assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects)
|
||||||
|
piaTransport.TLSClientConfig.RootCAs = nil
|
||||||
|
assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_unpackPIAPayload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const exampleToken = "token"
|
||||||
|
const examplePort = 2000
|
||||||
|
exampleExpiration := time.Unix(1000, 0).UTC()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
payload string
|
||||||
|
port uint16
|
||||||
|
token string
|
||||||
|
expiration time.Time
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
"valid payload": {
|
||||||
|
payload: makePIAPayload(t, exampleToken, examplePort, exampleExpiration),
|
||||||
|
port: examplePort,
|
||||||
|
token: exampleToken,
|
||||||
|
expiration: exampleExpiration,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
"invalid base64 payload": {
|
||||||
|
payload: "invalid",
|
||||||
|
err: errors.New(`cannot decode payload: payload is "invalid": illegal base64 data at input byte 4`),
|
||||||
|
},
|
||||||
|
"invalid json payload": {
|
||||||
|
payload: base64.StdEncoding.EncodeToString([]byte{1}),
|
||||||
|
err: errors.New(`cannot parse payload data: data is "\x01": invalid character '\x01' looking for beginning of value`), //nolint:lll
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
port, token, expiration, err := unpackPIAPayload(testCase.payload)
|
||||||
|
|
||||||
|
if testCase.err != nil {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.port, port)
|
||||||
|
assert.Equal(t, testCase.token, token)
|
||||||
|
assert.Equal(t, testCase.expiration, expiration)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makePIAPayload(t *testing.T, token string, port uint16, expiration time.Time) (payload string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
data := piaPayload{
|
||||||
|
Token: token,
|
||||||
|
Port: port,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user