Maintenance: split each provider in a package
- Fix VyprVPN port - Fix missing Auth overrides
This commit is contained in:
65
internal/provider/privateinternetaccess/connection.go
Normal file
65
internal/provider/privateinternetaccess/connection.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
func (p *PIA) GetOpenVPNConnection(selection configuration.ServerSelection) (
|
||||
connection models.OpenVPNConnection, err error) {
|
||||
protocol := constants.UDP
|
||||
if selection.TCP {
|
||||
protocol = constants.TCP
|
||||
}
|
||||
|
||||
port, err := getPort(selection.TCP, selection.EncryptionPreset, selection.CustomPort)
|
||||
if err != nil {
|
||||
return connection, err
|
||||
}
|
||||
|
||||
servers, err := p.filterServers(selection)
|
||||
if err != nil {
|
||||
return connection, err
|
||||
}
|
||||
|
||||
var connections []models.OpenVPNConnection
|
||||
for _, server := range servers {
|
||||
for _, IP := range server.IPs {
|
||||
connection := models.OpenVPNConnection{
|
||||
IP: IP,
|
||||
Port: port,
|
||||
Protocol: protocol,
|
||||
}
|
||||
connections = append(connections, connection)
|
||||
}
|
||||
}
|
||||
|
||||
if selection.TargetIP != nil {
|
||||
connection, err = utils.GetTargetIPConnection(connections, selection.TargetIP)
|
||||
} else {
|
||||
connection, err = utils.PickRandomConnection(connections, p.randSource), nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return connection, err
|
||||
}
|
||||
|
||||
p.activeServer = findActiveServer(servers, connection)
|
||||
|
||||
return connection, nil
|
||||
}
|
||||
|
||||
func findActiveServer(servers []models.PIAServer,
|
||||
connection models.OpenVPNConnection) (activeServer models.PIAServer) {
|
||||
// Reverse lookup server using the randomly picked connection
|
||||
for _, server := range servers {
|
||||
for _, ip := range server.IPs {
|
||||
if connection.IP.Equal(ip) {
|
||||
return server
|
||||
}
|
||||
}
|
||||
}
|
||||
return activeServer
|
||||
}
|
||||
29
internal/provider/privateinternetaccess/filter.go
Normal file
29
internal/provider/privateinternetaccess/filter.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
func (p *PIA) filterServers(selection configuration.ServerSelection) (
|
||||
servers []models.PIAServer, err error) {
|
||||
for _, server := range p.servers {
|
||||
switch {
|
||||
case
|
||||
utils.FilterByPossibilities(server.Region, selection.Regions),
|
||||
utils.FilterByPossibilities(server.Hostname, selection.Hostnames),
|
||||
utils.FilterByPossibilities(server.ServerName, selection.Names),
|
||||
selection.TCP && !server.TCP,
|
||||
!selection.TCP && !server.UDP:
|
||||
default:
|
||||
servers = append(servers, server)
|
||||
}
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
return nil, utils.NoServerFoundError(selection)
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
57
internal/provider/privateinternetaccess/httpclient.go
Normal file
57
internal/provider/privateinternetaccess/httpclient.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrParseCertificate = errors.New("cannot parse X509 certificate")
|
||||
)
|
||||
|
||||
func newHTTPClient(serverName string) (client *http.Client, err error) {
|
||||
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrParseCertificate, err)
|
||||
}
|
||||
certificate, err := x509.ParseCertificate(certificateBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrParseCertificate, err)
|
||||
}
|
||||
|
||||
//nolint:gomnd
|
||||
transport := &http.Transport{
|
||||
// Settings taken from http.DefaultTransport
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(certificate)
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
const httpTimeout = 30 * time.Second
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: httpTimeout,
|
||||
}, nil
|
||||
}
|
||||
83
internal/provider/privateinternetaccess/openvpnconf.go
Normal file
83
internal/provider/privateinternetaccess/openvpnconf.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
)
|
||||
|
||||
func (p *PIA) BuildConf(connection models.OpenVPNConnection,
|
||||
username string, settings configuration.OpenVPN) (lines []string) {
|
||||
var defaultCipher, defaultAuth, X509CRL, certificate string
|
||||
if settings.Provider.ExtraConfigOptions.EncryptionPreset == constants.PIAEncryptionPresetNormal {
|
||||
defaultCipher = constants.AES128cbc
|
||||
defaultAuth = constants.SHA1
|
||||
X509CRL = constants.PiaX509CRLNormal
|
||||
certificate = constants.PIACertificateNormal
|
||||
} else { // strong encryption
|
||||
defaultCipher = constants.AES256cbc
|
||||
defaultAuth = constants.SHA256
|
||||
X509CRL = constants.PiaX509CRLStrong
|
||||
certificate = constants.PIACertificateStrong
|
||||
}
|
||||
|
||||
if settings.Cipher == "" {
|
||||
settings.Cipher = defaultCipher
|
||||
}
|
||||
|
||||
if settings.Auth == "" {
|
||||
settings.Auth = defaultAuth
|
||||
}
|
||||
|
||||
lines = []string{
|
||||
"client",
|
||||
"dev tun",
|
||||
"nobind",
|
||||
"persist-key",
|
||||
"remote-cert-tls server",
|
||||
|
||||
// PIA specific
|
||||
"reneg-sec 0",
|
||||
"disable-occ",
|
||||
"compress", // allow PIA server to choose the compression to use
|
||||
|
||||
// Added constant values
|
||||
"auth-nocache",
|
||||
"mute-replay-warnings",
|
||||
"pull-filter ignore \"auth-token\"", // prevent auth failed loops
|
||||
"auth-retry nointeract",
|
||||
"suppress-timestamps",
|
||||
|
||||
// Modified variables
|
||||
"verb " + strconv.Itoa(settings.Verbosity),
|
||||
"auth-user-pass " + constants.OpenVPNAuthConf,
|
||||
connection.ProtoLine(),
|
||||
connection.RemoteLine(),
|
||||
"data-ciphers-fallback " + settings.Cipher,
|
||||
"data-ciphers " + settings.Cipher,
|
||||
"auth " + settings.Auth,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(settings.Cipher, "-gcm") {
|
||||
lines = append(lines, "ncp-disable")
|
||||
}
|
||||
|
||||
if !settings.Root {
|
||||
lines = append(lines, "user "+username)
|
||||
}
|
||||
|
||||
if settings.MSSFix > 0 {
|
||||
lines = append(lines, "mssfix "+strconv.Itoa(int(settings.MSSFix)))
|
||||
}
|
||||
|
||||
lines = append(lines, utils.WrapOpenvpnCA(certificate)...)
|
||||
lines = append(lines, utils.WrapOpenvpnCRLVerify(X509CRL)...)
|
||||
|
||||
lines = append(lines, "")
|
||||
|
||||
return lines
|
||||
}
|
||||
59
internal/provider/privateinternetaccess/port.go
Normal file
59
internal/provider/privateinternetaccess/port.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func getPort(tcp bool, encryptionPreset string, customPort uint16) (
|
||||
port uint16, err error) {
|
||||
if customPort == 0 {
|
||||
return getDefaultPort(tcp, encryptionPreset), nil
|
||||
}
|
||||
|
||||
if err := checkPort(customPort, tcp); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return customPort, nil
|
||||
}
|
||||
|
||||
func getDefaultPort(tcp bool, encryptionPreset string) (port uint16) {
|
||||
if tcp {
|
||||
switch encryptionPreset {
|
||||
case constants.PIAEncryptionPresetNormal:
|
||||
port = 502
|
||||
case constants.PIAEncryptionPresetStrong:
|
||||
port = 501
|
||||
}
|
||||
} else {
|
||||
switch encryptionPreset {
|
||||
case constants.PIAEncryptionPresetNormal:
|
||||
port = 1198
|
||||
case constants.PIAEncryptionPresetStrong:
|
||||
port = 1197
|
||||
}
|
||||
}
|
||||
return port
|
||||
}
|
||||
|
||||
var ErrInvalidPort = errors.New("invalid port number")
|
||||
|
||||
func checkPort(port uint16, tcp bool) (err error) {
|
||||
if tcp {
|
||||
switch port {
|
||||
case 80, 110, 443: //nolint:gomnd
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: %d for protocol TCP", ErrInvalidPort, port)
|
||||
}
|
||||
}
|
||||
switch port {
|
||||
case 53, 1194, 1197, 1198, 8080, 9201: //nolint:gomnd
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: %d for protocol UDP", ErrInvalidPort, port)
|
||||
}
|
||||
}
|
||||
508
internal/provider/privateinternetaccess/portforward.go
Normal file
508
internal/provider/privateinternetaccess/portforward.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
format "github.com/qdm12/gluetun/internal/logging"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBindPort = errors.New("cannot bind port")
|
||||
)
|
||||
|
||||
//nolint:gocognit
|
||||
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
defer logger.Warn("loop exited")
|
||||
|
||||
commonName := p.activeServer.ServerName
|
||||
if !p.activeServer.PortForward {
|
||||
logger.Error("The server " + commonName +
|
||||
" (region " + p.activeServer.Region + ") does not support port forwarding")
|
||||
return
|
||||
}
|
||||
if gateway == nil {
|
||||
logger.Error("aborting because: VPN gateway IP address was not found")
|
||||
return
|
||||
}
|
||||
|
||||
privateIPClient, err := newHTTPClient(commonName)
|
||||
if err != nil {
|
||||
logger.Error("aborting because: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
data, err := readPIAPortForwardData(openFile)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
dataFound := data.Port > 0
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
expired := durationToExpiration <= 0
|
||||
|
||||
if dataFound {
|
||||
logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port)))
|
||||
if expired {
|
||||
logger.Warn("Forwarded port data expired on " +
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
} else {
|
||||
logger.Info("Forwarded port data expires in " + format.FormatDuration(durationToExpiration))
|
||||
}
|
||||
}
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
||||
}
|
||||
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
|
||||
" expiring in " + format.FormatDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
expiryTimer := time.NewTimer(durationToExpiration)
|
||||
const keepAlivePeriod = 15 * time.Minute
|
||||
// Timer behaving as a ticker
|
||||
keepAliveTimer := time.NewTimer(keepAlivePeriod)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := fw.RemoveAllowedPort(removeCtx, data.Port); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
if !keepAliveTimer.Stop() {
|
||||
<-keepAliveTimer.C
|
||||
}
|
||||
if !expiryTimer.Stop() {
|
||||
<-expiryTimer.C
|
||||
}
|
||||
return
|
||||
case <-keepAliveTimer.C:
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
logger.Error("cannot bind port: " + err.Error())
|
||||
}
|
||||
keepAliveTimer.Reset(keepAlivePeriod)
|
||||
case <-expiryTimer.C:
|
||||
logger.Warn("Forward port has expired on " +
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
|
||||
" expiring in " + format.FormatDuration(durationToExpiration))
|
||||
if err := fw.RemoveAllowedPort(ctx, oldPort); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
logger.Error("Cannot write port forward data to file: " + err.Error())
|
||||
}
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
logger.Error("Cannot bind port: " + err.Error())
|
||||
}
|
||||
if !keepAliveTimer.Stop() {
|
||||
<-keepAliveTimer.C
|
||||
}
|
||||
keepAliveTimer.Reset(keepAlivePeriod)
|
||||
expiryTimer.Reset(durationToExpiration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrFetchToken = errors.New("cannot fetch token")
|
||||
ErrFetchPortForwarding = errors.New("cannot fetch port forwarding data")
|
||||
ErrPersistPortForwarding = errors.New("cannot persist port forwarding data")
|
||||
)
|
||||
|
||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, openFile, client)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
|
||||
}
|
||||
|
||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
|
||||
}
|
||||
|
||||
if err := writePIAPortForwardData(openFile, data); err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type piaPayload struct {
|
||||
Token string `json:"token"`
|
||||
Port uint16 `json:"port"`
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type piaPortForwardData struct {
|
||||
Port uint16 `json:"port"`
|
||||
Token string `json:"token"`
|
||||
Signature string `json:"signature"`
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
file, err := openFile(constants.PIAPortForward, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return data, nil
|
||||
} else if err != nil {
|
||||
return data, err
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
_ = file.Close()
|
||||
return data, err
|
||||
}
|
||||
|
||||
return data, file.Close()
|
||||
}
|
||||
|
||||
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
|
||||
file, err := openFile(constants.PIAPortForward, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(file)
|
||||
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
func unpackPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
|
||||
b, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return 0, "", expiration,
|
||||
fmt.Errorf("%w: for payload: %s", err, payload)
|
||||
}
|
||||
|
||||
var payloadData piaPayload
|
||||
if err := json.Unmarshal(b, &payloadData); err != nil {
|
||||
return 0, "", expiration,
|
||||
fmt.Errorf("%w: for data: %s", err, string(b))
|
||||
}
|
||||
|
||||
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
|
||||
}
|
||||
|
||||
func packPayload(port uint16, token string, expiration time.Time) (payload string, err error) {
|
||||
payloadData := piaPayload{
|
||||
Token: token,
|
||||
Port: port,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
b, err := json.Marshal(&payloadData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
payload = base64.StdEncoding.EncodeToString(b)
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errGetCredentials = errors.New("cannot get username and password")
|
||||
errEmptyToken = errors.New("token received is empty")
|
||||
)
|
||||
|
||||
func fetchToken(ctx context.Context, openFile os.OpenFileFunc,
|
||||
client *http.Client) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(openFile)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %s", errGetCredentials, err)
|
||||
}
|
||||
|
||||
errSubstitutions := map[string]string{
|
||||
username: "<username>",
|
||||
password: "<password>",
|
||||
}
|
||||
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
User: url.UserPassword(username, password),
|
||||
Host: "privateinternetaccess.com",
|
||||
Path: "/gtoken/generateToken",
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
return "", replaceInErr(err, errSubstitutions)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", replaceInErr(err, errSubstitutions)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return "", makeNOKStatusError(response, nil)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := decoder.Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
return "", errEmptyToken
|
||||
}
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errAuthFileRead = errors.New("cannot read OpenVPN authentication file")
|
||||
errAuthFileMalformed = errors.New("authentication file is malformed")
|
||||
)
|
||||
|
||||
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
|
||||
file, err := openFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
||||
}
|
||||
|
||||
authData, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
||||
}
|
||||
|
||||
if err := file.Close(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
lines := strings.Split(string(authData), "\n")
|
||||
const minLines = 2
|
||||
if len(lines) < minLines {
|
||||
return "", "", fmt.Errorf("%w: only %d lines exist", errAuthFileMalformed, len(lines))
|
||||
}
|
||||
|
||||
username, password = lines[0], lines[1]
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errGetSignaturePayload = errors.New("cannot obtain signature payload")
|
||||
errUnpackPayload = errors.New("cannot unpack payload data")
|
||||
)
|
||||
|
||||
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (
|
||||
port uint16, signature string, expiration time.Time, err error) {
|
||||
errSubstitutions := map[string]string{token: "<token>"}
|
||||
|
||||
queryParams := new(url.Values)
|
||||
queryParams.Add("token", token)
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Path: "/getSignature",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
err = replaceInErr(err, errSubstitutions)
|
||||
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
err = replaceInErr(err, errSubstitutions)
|
||||
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return 0, "", expiration, makeNOKStatusError(response, errSubstitutions)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
var data struct {
|
||||
Status string `json:"status"`
|
||||
Payload string `json:"payload"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
|
||||
}
|
||||
|
||||
if data.Status != "OK" {
|
||||
return 0, "", expiration, fmt.Errorf("%w: status is: %s", ErrBadResponse, data.Status)
|
||||
}
|
||||
|
||||
port, _, expiration, err = unpackPayload(data.Payload)
|
||||
if err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("%w: %s", errUnpackPayload, err)
|
||||
}
|
||||
return port, data.Signature, expiration, err
|
||||
}
|
||||
|
||||
var (
|
||||
ErrSerializePayload = errors.New("cannot serialize payload")
|
||||
ErrUnmarshalResponse = errors.New("cannot unmarshal response")
|
||||
ErrBadResponse = errors.New("bad response received")
|
||||
)
|
||||
|
||||
func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
||||
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrSerializePayload, err)
|
||||
}
|
||||
|
||||
queryParams := new(url.Values)
|
||||
queryParams.Add("payload", payload)
|
||||
queryParams.Add("signature", data.Signature)
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Path: "/bindPort",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
|
||||
errSubstitutions := map[string]string{
|
||||
payload: "<payload>",
|
||||
data.Signature: "<signature>",
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
return replaceInErr(err, errSubstitutions)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return replaceInErr(err, errSubstitutions)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return makeNOKStatusError(response, errSubstitutions)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
var responseData struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := decoder.Decode(&responseData); err != nil {
|
||||
return fmt.Errorf("%w: from %s: %s", ErrUnmarshalResponse, url.String(), err)
|
||||
}
|
||||
|
||||
if responseData.Status != "OK" {
|
||||
return fmt.Errorf("%w: %s: %s", ErrBadResponse, responseData.Status, responseData.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(openFile os.OpenFileFunc,
|
||||
filepath string, port uint16) (err error) {
|
||||
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
// replaceInErr is used to remove sensitive information from errors.
|
||||
func replaceInErr(err error, substitutions map[string]string) error {
|
||||
s := replaceInString(err.Error(), substitutions)
|
||||
return errors.New(s)
|
||||
}
|
||||
|
||||
// replaceInString is used to remove sensitive information.
|
||||
func replaceInString(s string, substitutions map[string]string) string {
|
||||
for old, new := range substitutions {
|
||||
s = strings.ReplaceAll(s, old, new)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
var ErrHTTPStatusCodeNotOK = errors.New("HTTP status code is not OK")
|
||||
|
||||
func makeNOKStatusError(response *http.Response, substitutions map[string]string) (err error) {
|
||||
url := response.Request.URL.String()
|
||||
url = replaceInString(url, substitutions)
|
||||
|
||||
b, _ := ioutil.ReadAll(response.Body)
|
||||
shortenMessage := string(b)
|
||||
shortenMessage = strings.ReplaceAll(shortenMessage, "\n", "")
|
||||
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
|
||||
shortenMessage = replaceInString(shortenMessage, substitutions)
|
||||
|
||||
return fmt.Errorf("%w: %s: %s: response received: %s",
|
||||
ErrHTTPStatusCodeNotOK, url, response.Status, shortenMessage)
|
||||
}
|
||||
143
internal/provider/privateinternetaccess/portforward_test.go
Normal file
143
internal/provider/privateinternetaccess/portforward_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_newHTTPClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const serverName = "testserver"
|
||||
|
||||
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
|
||||
require.NoError(t, err)
|
||||
certificate, err := x509.ParseCertificate(certificateBytes)
|
||||
require.NoError(t, err)
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(certificate)
|
||||
expectedRootCAsSubjects := rootCAs.Subjects()
|
||||
|
||||
expectedPIATransportTLSConfig := &tls.Config{
|
||||
// Can't directly compare RootCAs because of private fields
|
||||
RootCAs: nil,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
piaClient, err := newHTTPClient(serverName)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify pia transport TLS config is set
|
||||
piaTransport := piaClient.Transport.(*http.Transport)
|
||||
rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects()
|
||||
assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects)
|
||||
piaTransport.TLSClientConfig.RootCAs = nil
|
||||
assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig)
|
||||
}
|
||||
|
||||
func Test_unpackPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const exampleToken = "token"
|
||||
const examplePort = 2000
|
||||
exampleExpiration := time.Unix(1000, 0).UTC()
|
||||
|
||||
testCases := map[string]struct {
|
||||
payload string
|
||||
port uint16
|
||||
token string
|
||||
expiration time.Time
|
||||
err error
|
||||
}{
|
||||
"valid payload": {
|
||||
payload: makePIAPayload(t, exampleToken, examplePort, exampleExpiration),
|
||||
port: examplePort,
|
||||
token: exampleToken,
|
||||
expiration: exampleExpiration,
|
||||
err: nil,
|
||||
},
|
||||
"invalid base64 payload": {
|
||||
payload: "invalid",
|
||||
err: errors.New("illegal base64 data at input byte 4: for payload: invalid"),
|
||||
},
|
||||
"invalid json payload": {
|
||||
payload: base64.StdEncoding.EncodeToString([]byte{1}),
|
||||
err: errors.New("invalid character '\\x01' looking for beginning of value: for data: \x01"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
port, token, expiration, err := unpackPayload(testCase.payload)
|
||||
|
||||
if testCase.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, testCase.err.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testCase.port, port)
|
||||
assert.Equal(t, testCase.token, token)
|
||||
assert.Equal(t, testCase.expiration, expiration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makePIAPayload(t *testing.T, token string, port uint16, expiration time.Time) (payload string) {
|
||||
t.Helper()
|
||||
|
||||
data := piaPayload{
|
||||
Token: token,
|
||||
Port: port,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
b, err := json.Marshal(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func Test_replaceInString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
s string
|
||||
substitutions map[string]string
|
||||
result string
|
||||
}{
|
||||
"empty": {},
|
||||
"multiple replacements": {
|
||||
s: "https://test.com/username/password/",
|
||||
substitutions: map[string]string{
|
||||
"username": "xxx",
|
||||
"password": "yyy",
|
||||
},
|
||||
result: "https://test.com/xxx/yyy/",
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := replaceInString(testCase.s, testCase.substitutions)
|
||||
assert.Equal(t, testCase.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
23
internal/provider/privateinternetaccess/provider.go
Normal file
23
internal/provider/privateinternetaccess/provider.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type PIA struct {
|
||||
servers []models.PIAServer
|
||||
randSource rand.Source
|
||||
timeNow func() time.Time
|
||||
activeServer models.PIAServer
|
||||
}
|
||||
|
||||
func New(servers []models.PIAServer, randSource rand.Source, timeNow func() time.Time) *PIA {
|
||||
return &PIA{
|
||||
servers: servers,
|
||||
timeNow: timeNow,
|
||||
randSource: randSource,
|
||||
}
|
||||
}
|
||||
31
internal/provider/privateinternetaccess/try.go
Normal file
31
internal/provider/privateinternetaccess/try.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
|
||||
const initialRetryPeriod = 5 * time.Second
|
||||
retryPeriod := initialRetryPeriod
|
||||
for {
|
||||
err := fn()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
logger.Error(err)
|
||||
logger.Info("Trying again in " + retryPeriod.String())
|
||||
timer := time.NewTimer(retryPeriod)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return
|
||||
}
|
||||
retryPeriod *= 2
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user