diff --git a/Dockerfile b/Dockerfile index f85b5c10..2d888495 100644 --- a/Dockerfile +++ b/Dockerfile @@ -147,5 +147,4 @@ RUN apk add -q --progress --no-cache --update openvpn ca-certificates iptables i deluser unbound && \ mkdir /gluetun # TODO remove once SAN is added to PIA servers certificates, see https://github.com/pia-foss/manual-connections/issues/10 -ENV GODEBUG=x509ignoreCN=0 COPY --from=build /tmp/gobuild/entrypoint /entrypoint diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go index 43ad3231..9591628c 100644 --- a/internal/provider/piav4.go +++ b/internal/provider/piav4.go @@ -221,7 +221,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, return } - client, err := newPIAHTTPClient(commonName) + privateIPClient, err := newPIAHTTPClient(commonName) if err != nil { pfLogger.Error("aborting because: %s", err) return @@ -246,7 +246,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, if !dataFound || expired { tryUntilSuccessful(ctx, pfLogger, func() error { - data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile) + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile) return err }) if ctx.Err() != nil { @@ -258,7 +258,10 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, // First time binding tryUntilSuccessful(ctx, pfLogger, func() error { - return bindPIAPort(ctx, client, gateway, data) + if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil { + return fmt.Errorf("cannot bind port: %w", err) + } + return nil }) if ctx.Err() != nil { return @@ -294,15 +297,15 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, } return case <-keepAliveTimer.C: - if err := bindPIAPort(ctx, client, gateway, data); err != nil { - pfLogger.Error(err) + if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil { + pfLogger.Error("cannot bind port: " + err.Error()) } keepAliveTimer.Reset(keepAlivePeriod) case <-expiryTimer.C: pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) oldPort := data.Port for { - data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile) + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile) if err != nil { pfLogger.Error(err) continue @@ -322,8 +325,8 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { pfLogger.Error(err) } - if err := bindPIAPort(ctx, client, gateway, data); err != nil { - pfLogger.Error(err) + if err := bindPIAPort(ctx, privateIPClient, gateway, data); err != nil { + pfLogger.Error("cannot bind port: " + err.Error()) } if !keepAliveTimer.Stop() { <-keepAliveTimer.C @@ -357,22 +360,14 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) { if err != nil { return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err) } - // certificate.DNSNames = []string{serverName, "10.0.0.1"} - rootCAs := x509.NewCertPool() - rootCAs.AddCert(certificate) - TLSClientConfig := &tls.Config{ - RootCAs: rootCAs, - MinVersion: tls.VersionTLS12, - ServerName: serverName, - } + //nolint:gomnd - transport := http.Transport{ - TLSClientConfig: TLSClientConfig, - Proxy: http.ProxyFromEnvironment, + transport := &http.Transport{ + // Settings taken from http.DefaultTransport + Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -380,18 +375,28 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(certificate) + transport.TLSClientConfig = &tls.Config{ + RootCAs: rootCAs, + MinVersion: tls.VersionTLS12, + ServerName: serverName, + } + const httpTimeout = 30 * time.Second - client = &http.Client{Transport: &transport, Timeout: httpTimeout} - return client, nil + return &http.Client{ + Transport: transport, + Timeout: httpTimeout, + }, nil } -func refreshPIAPortForwardData(ctx context.Context, client *http.Client, +func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client, gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) { data.Token, err = fetchPIAToken(ctx, openFile, client) if err != nil { return data, fmt.Errorf("cannot obtain token: %w", err) } - data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token) + data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, privateIPClient, gateway, data.Token) if err != nil { return data, fmt.Errorf("cannot obtain port forwarding data: %w", err) } @@ -448,13 +453,15 @@ func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) } func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) { - b, err := base64.RawStdEncoding.DecodeString(payload) + b, err := base64.StdEncoding.DecodeString(payload) if err != nil { - return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err) + return 0, "", expiration, + fmt.Errorf("cannot decode payload: payload is %q: %w", payload, err) } var payloadData piaPayload if err := json.Unmarshal(b, &payloadData); err != nil { - return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err) + return 0, "", expiration, + fmt.Errorf("cannot parse payload data: data is %q: %w", string(b), err) } return payloadData.Port, payloadData.Token, payloadData.Expiration, nil } @@ -469,7 +476,7 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st if err != nil { return "", fmt.Errorf("cannot serialize payload data: %w", err) } - payload = base64.RawStdEncoding.EncodeToString(b) + payload = base64.StdEncoding.EncodeToString(b) return payload, nil } @@ -482,16 +489,18 @@ func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc, url := url.URL{ Scheme: "https", User: url.UserPassword(username, password), - Host: "10.0.0.1", - Path: "/authv3/generateToken", + Host: "privateinternetaccess.com", + Path: "/gtoken/generateToken", } request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) if err != nil { - return "", err + return "", replaceInErr(err, map[string]string{ + username: "", password: ""}) } response, err := client.Do(request) if err != nil { - return "", err + return "", replaceInErr(err, map[string]string{ + username: "", password: ""}) } defer response.Body.Close() if response.StatusCode != http.StatusOK { @@ -547,10 +556,12 @@ func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway n } request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) if err != nil { + err = replaceInErr(err, map[string]string{token: ""}) return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) } response, err := client.Do(request) if err != nil { + err = replaceInErr(err, map[string]string{token: ""}) return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) } defer response.Body.Close() @@ -590,11 +601,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) if err != nil { - return fmt.Errorf("cannot bind port: %w", err) + return replaceInErr(err, map[string]string{ + payload: "", + data.Signature: "", + }) } response, err := client.Do(request) if err != nil { - return fmt.Errorf("cannot bind port: %w", err) + return replaceInErr(err, map[string]string{ + payload: "", + data.Signature: "", + }) } defer response.Body.Close() if response.StatusCode != http.StatusOK { @@ -607,7 +624,7 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data Message string `json:"message"` } if err := decoder.Decode(&responseData); err != nil { - return fmt.Errorf("cannot bind port: %w", err) + return err } else if responseData.Status != "OK" { return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message) } @@ -627,3 +644,12 @@ func writePortForwardedToFile(openFile os.OpenFileFunc, } return file.Close() } + +// replaceInErr is used to remove sensitive information from logs. +func replaceInErr(err error, substitutions map[string]string) error { + s := err.Error() + for old, new := range substitutions { + s = strings.ReplaceAll(s, old, new) + } + return errors.New(s) +} diff --git a/internal/provider/piav4_test.go b/internal/provider/piav4_test.go new file mode 100644 index 00000000..b03284b3 --- /dev/null +++ b/internal/provider/piav4_test.go @@ -0,0 +1,114 @@ +package provider + +import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "testing" + "time" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_newPIAHTTPClient(t *testing.T) { + t.Parallel() + + const serverName = "testserver" + + certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong) + require.NoError(t, err) + certificate, err := x509.ParseCertificate(certificateBytes) + require.NoError(t, err) + rootCAs := x509.NewCertPool() + rootCAs.AddCert(certificate) + expectedRootCAsSubjects := rootCAs.Subjects() + + expectedPIATransportTLSConfig := &tls.Config{ + // Can't directly compare RootCAs because of private fields + RootCAs: nil, + MinVersion: tls.VersionTLS12, + ServerName: serverName, + } + + piaClient, err := newPIAHTTPClient(serverName) + + require.NoError(t, err) + + // Verify pia transport TLS config is set + piaTransport := piaClient.Transport.(*http.Transport) + rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects() + assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects) + piaTransport.TLSClientConfig.RootCAs = nil + assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig) +} + +func Test_unpackPIAPayload(t *testing.T) { + t.Parallel() + + const exampleToken = "token" + const examplePort = 2000 + exampleExpiration := time.Unix(1000, 0).UTC() + + testCases := map[string]struct { + payload string + port uint16 + token string + expiration time.Time + err error + }{ + "valid payload": { + payload: makePIAPayload(t, exampleToken, examplePort, exampleExpiration), + port: examplePort, + token: exampleToken, + expiration: exampleExpiration, + err: nil, + }, + "invalid base64 payload": { + payload: "invalid", + err: errors.New(`cannot decode payload: payload is "invalid": illegal base64 data at input byte 4`), + }, + "invalid json payload": { + payload: base64.StdEncoding.EncodeToString([]byte{1}), + err: errors.New(`cannot parse payload data: data is "\x01": invalid character '\x01' looking for beginning of value`), //nolint:lll + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + port, token, expiration, err := unpackPIAPayload(testCase.payload) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + require.NoError(t, err) + } + + assert.Equal(t, testCase.port, port) + assert.Equal(t, testCase.token, token) + assert.Equal(t, testCase.expiration, expiration) + }) + } +} + +func makePIAPayload(t *testing.T, token string, port uint16, expiration time.Time) (payload string) { + t.Helper() + + data := piaPayload{ + Token: token, + Port: port, + Expiration: expiration, + } + + b, err := json.Marshal(data) + require.NoError(t, err) + + return base64.StdEncoding.EncodeToString(b) +}