Fix: PIA port forwarding (#427)

- Update PIA token URL
- Change base64 decoding to standard decoding
- Add unit tests
- Remove environment variable `GODEBUG=x509ignoreCN=0`
- Fixes #423 
- Fixes #292 
- Closes #264 
- Closes #293
This commit is contained in:
Quentin McGaw
2021-04-17 16:21:17 -04:00
committed by GitHub
parent 3795e92a82
commit 6208081788
3 changed files with 175 additions and 36 deletions

View File

@@ -147,5 +147,4 @@ RUN apk add -q --progress --no-cache --update openvpn ca-certificates iptables i
deluser unbound && \ deluser unbound && \
mkdir /gluetun mkdir /gluetun
# TODO remove once SAN is added to PIA servers certificates, see https://github.com/pia-foss/manual-connections/issues/10 # TODO remove once SAN is added to PIA servers certificates, see https://github.com/pia-foss/manual-connections/issues/10
ENV GODEBUG=x509ignoreCN=0
COPY --from=build /tmp/gobuild/entrypoint /entrypoint COPY --from=build /tmp/gobuild/entrypoint /entrypoint

View File

@@ -221,7 +221,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
return return
} }
client, err := newPIAHTTPClient(commonName) privateIPClient, err := newPIAHTTPClient(commonName)
if err != nil { if err != nil {
pfLogger.Error("aborting because: %s", err) pfLogger.Error("aborting because: %s", err)
return return
@@ -246,7 +246,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired { if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error { tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile) data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
return err return err
}) })
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -258,7 +258,10 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
// First time binding // First time binding
tryUntilSuccessful(ctx, pfLogger, func() error { tryUntilSuccessful(ctx, pfLogger, func() error {
return bindPIAPort(ctx, client, gateway, data) if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}
return nil
}) })
if ctx.Err() != nil { if ctx.Err() != nil {
return return
@@ -294,15 +297,15 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
} }
return return
case <-keepAliveTimer.C: case <-keepAliveTimer.C:
if err := bindPIAPort(ctx, client, gateway, data); err != nil { if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
pfLogger.Error(err) pfLogger.Error("cannot bind port: " + err.Error())
} }
keepAliveTimer.Reset(keepAlivePeriod) keepAliveTimer.Reset(keepAlivePeriod)
case <-expiryTimer.C: case <-expiryTimer.C:
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port oldPort := data.Port
for { for {
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile) data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
if err != nil { if err != nil {
pfLogger.Error(err) pfLogger.Error(err)
continue continue
@@ -322,8 +325,8 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
pfLogger.Error(err) pfLogger.Error(err)
} }
if err := bindPIAPort(ctx, client, gateway, data); err != nil { if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil {
pfLogger.Error(err) pfLogger.Error("cannot bind port: " + err.Error())
} }
if !keepAliveTimer.Stop() { if !keepAliveTimer.Stop() {
<-keepAliveTimer.C <-keepAliveTimer.C
@@ -357,22 +360,14 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err) return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
} }
// certificate.DNSNames = []string{serverName, "10.0.0.1"}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificate)
TLSClientConfig := &tls.Config{
RootCAs: rootCAs,
MinVersion: tls.VersionTLS12,
ServerName: serverName,
}
//nolint:gomnd //nolint:gomnd
transport := http.Transport{ transport := &http.Transport{
TLSClientConfig: TLSClientConfig, // Settings taken from http.DefaultTransport
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext, }).DialContext,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
MaxIdleConns: 100, MaxIdleConns: 100,
@@ -380,18 +375,28 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificate)
transport.TLSClientConfig = &tls.Config{
RootCAs: rootCAs,
MinVersion: tls.VersionTLS12,
ServerName: serverName,
}
const httpTimeout = 30 * time.Second const httpTimeout = 30 * time.Second
client = &http.Client{Transport: &transport, Timeout: httpTimeout} return &http.Client{
return client, nil Transport: transport,
Timeout: httpTimeout,
}, nil
} }
func refreshPIAPortForwardData(ctx context.Context, client *http.Client, func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) { gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(ctx, openFile, client) data.Token, err = fetchPIAToken(ctx, openFile, client)
if err != nil { if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err) return data, fmt.Errorf("cannot obtain token: %w", err)
} }
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token) data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, privateIPClient, gateway, data.Token)
if err != nil { if err != nil {
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err) return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
} }
@@ -448,13 +453,15 @@ func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData)
} }
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) { func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
b, err := base64.RawStdEncoding.DecodeString(payload) b, err := base64.StdEncoding.DecodeString(payload)
if err != nil { if err != nil {
return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err) return 0, "", expiration,
fmt.Errorf("cannot decode payload: payload is %q: %w", payload, err)
} }
var payloadData piaPayload var payloadData piaPayload
if err := json.Unmarshal(b, &payloadData); err != nil { if err := json.Unmarshal(b, &payloadData); err != nil {
return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err) return 0, "", expiration,
fmt.Errorf("cannot parse payload data: data is %q: %w", string(b), err)
} }
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
} }
@@ -469,7 +476,7 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st
if err != nil { if err != nil {
return "", fmt.Errorf("cannot serialize payload data: %w", err) return "", fmt.Errorf("cannot serialize payload data: %w", err)
} }
payload = base64.RawStdEncoding.EncodeToString(b) payload = base64.StdEncoding.EncodeToString(b)
return payload, nil return payload, nil
} }
@@ -482,16 +489,18 @@ func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc,
url := url.URL{ url := url.URL{
Scheme: "https", Scheme: "https",
User: url.UserPassword(username, password), User: url.UserPassword(username, password),
Host: "10.0.0.1", Host: "privateinternetaccess.com",
Path: "/authv3/generateToken", Path: "/gtoken/generateToken",
} }
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil { if err != nil {
return "", err return "", replaceInErr(err, map[string]string{
username: "<username>", password: "<password>"})
} }
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
return "", err return "", replaceInErr(err, map[string]string{
username: "<username>", password: "<password>"})
} }
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
@@ -547,10 +556,12 @@ func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway n
} }
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil { if err != nil {
err = replaceInErr(err, map[string]string{token: "<token>"})
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
} }
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
err = replaceInErr(err, map[string]string{token: "<token>"})
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
} }
defer response.Body.Close() defer response.Body.Close()
@@ -590,11 +601,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil { if err != nil {
return fmt.Errorf("cannot bind port: %w", err) return replaceInErr(err, map[string]string{
payload: "<payload>",
data.Signature: "<signature>",
})
} }
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
return fmt.Errorf("cannot bind port: %w", err) return replaceInErr(err, map[string]string{
payload: "<payload>",
data.Signature: "<signature>",
})
} }
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
@@ -607,7 +624,7 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
Message string `json:"message"` Message string `json:"message"`
} }
if err := decoder.Decode(&responseData); err != nil { if err := decoder.Decode(&responseData); err != nil {
return fmt.Errorf("cannot bind port: %w", err) return err
} else if responseData.Status != "OK" { } else if responseData.Status != "OK" {
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message) return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
} }
@@ -627,3 +644,12 @@ func writePortForwardedToFile(openFile os.OpenFileFunc,
} }
return file.Close() return file.Close()
} }
// replaceInErr is used to remove sensitive information from logs.
func replaceInErr(err error, substitutions map[string]string) error {
s := err.Error()
for old, new := range substitutions {
s = strings.ReplaceAll(s, old, new)
}
return errors.New(s)
}

View File

@@ -0,0 +1,114 @@
package provider
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"testing"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_newPIAHTTPClient(t *testing.T) {
t.Parallel()
const serverName = "testserver"
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
require.NoError(t, err)
certificate, err := x509.ParseCertificate(certificateBytes)
require.NoError(t, err)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificate)
expectedRootCAsSubjects := rootCAs.Subjects()
expectedPIATransportTLSConfig := &tls.Config{
// Can't directly compare RootCAs because of private fields
RootCAs: nil,
MinVersion: tls.VersionTLS12,
ServerName: serverName,
}
piaClient, err := newPIAHTTPClient(serverName)
require.NoError(t, err)
// Verify pia transport TLS config is set
piaTransport := piaClient.Transport.(*http.Transport)
rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects()
assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects)
piaTransport.TLSClientConfig.RootCAs = nil
assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig)
}
func Test_unpackPIAPayload(t *testing.T) {
t.Parallel()
const exampleToken = "token"
const examplePort = 2000
exampleExpiration := time.Unix(1000, 0).UTC()
testCases := map[string]struct {
payload string
port uint16
token string
expiration time.Time
err error
}{
"valid payload": {
payload: makePIAPayload(t, exampleToken, examplePort, exampleExpiration),
port: examplePort,
token: exampleToken,
expiration: exampleExpiration,
err: nil,
},
"invalid base64 payload": {
payload: "invalid",
err: errors.New(`cannot decode payload: payload is "invalid": illegal base64 data at input byte 4`),
},
"invalid json payload": {
payload: base64.StdEncoding.EncodeToString([]byte{1}),
err: errors.New(`cannot parse payload data: data is "\x01": invalid character '\x01' looking for beginning of value`), //nolint:lll
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
port, token, expiration, err := unpackPIAPayload(testCase.payload)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
require.NoError(t, err)
}
assert.Equal(t, testCase.port, port)
assert.Equal(t, testCase.token, token)
assert.Equal(t, testCase.expiration, expiration)
})
}
}
func makePIAPayload(t *testing.T, token string, port uint16, expiration time.Time) (payload string) {
t.Helper()
data := piaPayload{
Token: token,
Port: port,
Expiration: expiration,
}
b, err := json.Marshal(data)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(b)
}