|
|
|
|
@@ -21,12 +21,8 @@ import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
|
|
|
|
ErrServerNameEmpty = errors.New("server name is empty")
|
|
|
|
|
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
|
|
|
|
|
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
|
|
|
|
|
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
|
|
|
|
|
ErrBindPort = errors.New("cannot bind port")
|
|
|
|
|
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
|
|
|
|
ErrServerNameEmpty = errors.New("server name is empty")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// PortForward obtains a VPN server side port forwarded from PIA.
|
|
|
|
|
@@ -53,12 +49,12 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|
|
|
|
|
|
|
|
|
privateIPClient, err := newHTTPClient(serverName)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
|
|
|
|
return 0, fmt.Errorf("cannot create custom HTTP client: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
data, err := readPIAPortForwardData(p.portForwardPath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
|
|
|
|
return 0, fmt.Errorf("cannot read saved port forwarded data: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dataFound := data.Port > 0
|
|
|
|
|
@@ -79,7 +75,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|
|
|
|
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
|
|
|
|
p.portForwardPath, p.authFilePath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
|
|
|
|
|
return 0, fmt.Errorf("cannot refresh port forward data: %w", err)
|
|
|
|
|
}
|
|
|
|
|
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
|
|
|
|
}
|
|
|
|
|
@@ -87,7 +83,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|
|
|
|
|
|
|
|
|
// First time binding
|
|
|
|
|
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
|
|
|
|
return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
|
|
|
|
|
return 0, fmt.Errorf("cannot bind port: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return data.Port, nil
|
|
|
|
|
@@ -101,12 +97,12 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
|
|
|
|
port uint16, gateway net.IP, serverName string) (err error) {
|
|
|
|
|
privateIPClient, err := newHTTPClient(serverName)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
|
|
|
|
return fmt.Errorf("cannot create custom HTTP client: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
data, err := readPIAPortForwardData(p.portForwardPath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
|
|
|
|
return fmt.Errorf("cannot read saved port forwarded data: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
|
|
|
|
@@ -128,7 +124,7 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
|
|
|
|
case <-keepAliveTimer.C:
|
|
|
|
|
err := bindPort(ctx, privateIPClient, gateway, data)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
|
|
|
|
return fmt.Errorf("cannot bind port: %w", err)
|
|
|
|
|
}
|
|
|
|
|
keepAliveTimer.Reset(keepAlivePeriod)
|
|
|
|
|
case <-expiryTimer.C:
|
|
|
|
|
@@ -138,26 +134,20 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
ErrFetchToken = errors.New("cannot fetch token")
|
|
|
|
|
ErrFetchPortForwarding = errors.New("cannot fetch port forwarding data")
|
|
|
|
|
ErrPersistPortForwarding = errors.New("cannot persist port forwarding data")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
|
|
|
|
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
|
|
|
|
|
data.Token, err = fetchToken(ctx, client, authFilePath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
|
|
|
|
|
return data, fmt.Errorf("cannot fetch token: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
|
|
|
|
|
return data, fmt.Errorf("cannot fetch port forwarding data: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := writePIAPortForwardData(portForwardPath, data); err != nil {
|
|
|
|
|
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
|
|
|
|
|
return data, fmt.Errorf("cannot persist port forwarding data: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return data, nil
|
|
|
|
|
@@ -242,15 +232,14 @@ func packPayload(port uint16, token string, expiration time.Time) (payload strin
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
errGetCredentials = errors.New("cannot get username and password")
|
|
|
|
|
errEmptyToken = errors.New("token received is empty")
|
|
|
|
|
errEmptyToken = errors.New("token received is empty")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func fetchToken(ctx context.Context, client *http.Client,
|
|
|
|
|
authFilePath string) (token string, err error) {
|
|
|
|
|
username, password, err := getOpenvpnCredentials(authFilePath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", fmt.Errorf("%w: %s", errGetCredentials, err)
|
|
|
|
|
return "", fmt.Errorf("cannot get username and password: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
errSubstitutions := map[string]string{
|
|
|
|
|
@@ -284,7 +273,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|
|
|
|
Token string `json:"token"`
|
|
|
|
|
}
|
|
|
|
|
if err := decoder.Decode(&result); err != nil {
|
|
|
|
|
return "", fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
|
|
|
|
|
return "", fmt.Errorf("cannot unmarshal response: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if result.Token == "" {
|
|
|
|
|
@@ -294,7 +283,6 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
errAuthFileRead = errors.New("cannot read OpenVPN authentication file")
|
|
|
|
|
errAuthFileMalformed = errors.New("authentication file is malformed")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -302,13 +290,13 @@ func getOpenvpnCredentials(authFilePath string) (
|
|
|
|
|
username, password string, err error) {
|
|
|
|
|
file, err := os.Open(authFilePath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
|
|
|
|
return "", "", fmt.Errorf("cannot read OpenVPN authentication file: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
authData, err := io.ReadAll(file)
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = file.Close()
|
|
|
|
|
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
|
|
|
|
return "", "", fmt.Errorf("authentication file is malformed: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := file.Close(); err != nil {
|
|
|
|
|
@@ -325,11 +313,6 @@ func getOpenvpnCredentials(authFilePath string) (
|
|
|
|
|
return username, password, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
errGetSignaturePayload = errors.New("cannot obtain signature payload")
|
|
|
|
|
errUnpackPayload = errors.New("cannot unpack payload data")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (
|
|
|
|
|
port uint16, signature string, expiration time.Time, err error) {
|
|
|
|
|
errSubstitutions := map[string]string{token: "<token>"}
|
|
|
|
|
@@ -345,13 +328,13 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
|
|
|
|
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
err = replaceInErr(err, errSubstitutions)
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature payload: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
response, err := client.Do(request)
|
|
|
|
|
if err != nil {
|
|
|
|
|
err = replaceInErr(err, errSubstitutions)
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature payload: %w", err)
|
|
|
|
|
}
|
|
|
|
|
defer response.Body.Close()
|
|
|
|
|
|
|
|
|
|
@@ -366,7 +349,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
|
|
|
|
|
Signature string `json:"signature"`
|
|
|
|
|
}
|
|
|
|
|
if err := decoder.Decode(&data); err != nil {
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("cannot unmarshal response: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if data.Status != "OK" {
|
|
|
|
|
@@ -375,21 +358,19 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
|
|
|
|
|
|
|
|
|
|
port, _, expiration, err = unpackPayload(data.Payload)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("%w: %s", errUnpackPayload, err)
|
|
|
|
|
return 0, "", expiration, fmt.Errorf("cannot unpack payload data: %w", err)
|
|
|
|
|
}
|
|
|
|
|
return port, data.Signature, expiration, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
ErrSerializePayload = errors.New("cannot serialize payload")
|
|
|
|
|
ErrUnmarshalResponse = errors.New("cannot unmarshal response")
|
|
|
|
|
ErrBadResponse = errors.New("bad response received")
|
|
|
|
|
ErrBadResponse = errors.New("bad response received")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
|
|
|
|
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("%w: %s", ErrSerializePayload, err)
|
|
|
|
|
return fmt.Errorf("cannot serialize payload: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
queryParams := make(url.Values)
|
|
|
|
|
@@ -428,7 +409,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
|
|
|
|
|
Message string `json:"message"`
|
|
|
|
|
}
|
|
|
|
|
if err := decoder.Decode(&responseData); err != nil {
|
|
|
|
|
return fmt.Errorf("%w: from %s: %s", ErrUnmarshalResponse, url.String(), err)
|
|
|
|
|
return fmt.Errorf("cannot unmarshal response: from %s: %w", url.String(), err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if responseData.Status != "OK" {
|
|
|
|
|
@@ -464,6 +445,7 @@ func makeNOKStatusError(response *http.Response, substitutions map[string]string
|
|
|
|
|
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
|
|
|
|
|
shortenMessage = replaceInString(shortenMessage, substitutions)
|
|
|
|
|
|
|
|
|
|
return fmt.Errorf("%w: %s: %s: response received: %s",
|
|
|
|
|
ErrHTTPStatusCodeNotOK, url, response.Status, shortenMessage)
|
|
|
|
|
return fmt.Errorf("%w: %s: %d %s: response received: %s",
|
|
|
|
|
ErrHTTPStatusCodeNotOK, url, response.StatusCode,
|
|
|
|
|
response.Status, shortenMessage)
|
|
|
|
|
}
|
|
|
|
|
|