diff --git a/README.md b/README.md index 42d69e66..eae9bdd4 100644 --- a/README.md +++ b/README.md @@ -115,8 +115,8 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo | 🏁 `PASSWORD` | | | Your password | | `REGION` | | One of the [PIA regions](https://www.privateinternetaccess.com/pages/network/) | VPN server region | | `PIA_ENCRYPTION` | `strong` | `normal`, `strong` | Encryption preset | - | `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server **for old only** | - | `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number **for old only** | + | `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server | + | `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number | - Mullvad diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 82110b7f..327ccab0 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "net/http" "os" "os/signal" @@ -188,7 +189,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, - ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel) + ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel) wg.Add(1) // wait for restartOpenvpn go openvpnLooper.Run(ctx, wg) @@ -341,10 +342,11 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, }) } +//nolint:gocognit func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{}, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, routing routing.Routing, logger logging.Logger, httpClient *http.Client, - versionInformation, portForwardingEnabled bool, startPortForward func()) { + versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) { defer wg.Done() tickerWg := &sync.WaitGroup{} // for linters only @@ -364,18 +366,35 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn tickerWg.Add(2) go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) - if portForwardingEnabled { - time.AfterFunc(5*time.Second, startPortForward) - } defaultInterface, _, err := routing.DefaultRoute() if err != nil { logger.Warn(err) } else { - vpnGatewayIP, err := routing.VPNGatewayIP(defaultInterface) + vpnDestination, err := routing.VPNDestinationIP(defaultInterface) if err != nil { logger.Warn(err) } else { - logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) + logger.Info("VPN routing IP address: %s", vpnDestination) + } + } + if portForwardingEnabled { + // TODO make instantaneous once v3 go out of service + const waitDuration = 5 * time.Second + timer := time.NewTimer(waitDuration) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + continue + case <-timer.C: + // vpnGateway required only for PIA v4 + vpnGateway, err := routing.VPNLocalGatewayIP() + if err != nil { + logger.Error(err) + } + logger.Info("VPN gateway IP address: %s", vpnGateway) + startPortForward(vpnGateway) } } case <-dnsReadyCh: diff --git a/internal/constants/paths.go b/internal/constants/paths.go index deb5a995..c0d2a9b8 100644 --- a/internal/constants/paths.go +++ b/internal/constants/paths.go @@ -15,6 +15,8 @@ const ( OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf" // OpenVPNConf is the file path to the OpenVPN client configuration file OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn" + // PIAPortForward is the file path to the port forwarding JSON information for PIA v4 servers + PIAPortForward models.Filepath = "/gluetun/piaportforward.json" // TunnelDevice is the file path to tun device TunnelDevice models.Filepath = "/dev/net/tun" // NetRoute is the path to the file containing information on the network route diff --git a/internal/constants/splash.go b/internal/constants/splash.go index 3172803d..8b2754f6 100644 --- a/internal/constants/splash.go +++ b/internal/constants/splash.go @@ -2,9 +2,9 @@ package constants const ( // Announcement is a message announcement - Announcement = "Update servers information see https://github.com/qdm12/gluetun/wiki/Update-servers-information" + Announcement = "Port forwarding is working for PIA v4 servers" // AnnouncementExpiration is the expiration date of the announcement in format yyyy-mm-dd - AnnouncementExpiration = "2020-10-10" + AnnouncementExpiration = "2020-11-15" ) const ( diff --git a/internal/logging/duration.go b/internal/logging/duration.go new file mode 100644 index 00000000..aa1869bb --- /dev/null +++ b/internal/logging/duration.go @@ -0,0 +1,29 @@ +package logging + +import ( + "fmt" + "time" +) + +func FormatDuration(duration time.Duration) string { + switch { + case duration < time.Minute: + seconds := int(duration.Round(time.Second).Seconds()) + if seconds < 2 { + return fmt.Sprintf("%d second", seconds) + } + return fmt.Sprintf("%d seconds", seconds) + case duration <= time.Hour: + minutes := int(duration.Round(time.Minute).Minutes()) + if minutes == 1 { + return "1 minute" + } + return fmt.Sprintf("%d minutes", minutes) + case duration < 48*time.Hour: + hours := int(duration.Truncate(time.Hour).Hours()) + return fmt.Sprintf("%d hours", hours) + default: + days := int(duration.Truncate(time.Hour).Hours() / 24) + return fmt.Sprintf("%d days", days) + } +} diff --git a/internal/version/version_test.go b/internal/logging/duration_test.go similarity index 91% rename from internal/version/version_test.go rename to internal/logging/duration_test.go index bd560665..593ed188 100644 --- a/internal/version/version_test.go +++ b/internal/logging/duration_test.go @@ -1,4 +1,4 @@ -package version +package logging import ( "testing" @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_formatDuration(t *testing.T) { +func Test_FormatDuration(t *testing.T) { t.Parallel() testCases := map[string]struct { duration time.Duration @@ -57,7 +57,7 @@ func Test_formatDuration(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() - s := formatDuration(testCase.duration) + s := FormatDuration(testCase.duration) assert.Equal(t, testCase.s, s) }) } diff --git a/internal/models/selection.go b/internal/models/selection.go index d0b36e65..4d664da8 100644 --- a/internal/models/selection.go +++ b/internal/models/selection.go @@ -90,6 +90,7 @@ func (p *ProviderSettings) String() string { settingsList = append(settingsList, "Region: "+p.ServerSelection.Region, "Encryption preset: "+p.ExtraConfigOptions.EncryptionPreset, + "Port forwarding: "+p.PortForwarding.String(), ) case "mullvad": settingsList = append(settingsList, diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 1f92cca0..22839759 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -2,7 +2,8 @@ package openvpn import ( "context" - "fmt" + "net" + "net/http" "sync" "time" @@ -10,17 +11,17 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider" + "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/network" ) type Looper interface { Run(ctx context.Context, wg *sync.WaitGroup) Restart() - PortForward() + PortForward(vpnGatewayIP net.IP) GetSettings() (settings settings.OpenVPN) SetSettings(settings settings.OpenVPN) GetPortForwarded() (portForwarded uint16) @@ -40,23 +41,24 @@ type looper struct { uid int gid int // Configurators - conf Configurator - fw firewall.Configurator + conf Configurator + fw firewall.Configurator + routing routing.Routing // Other objects - logger logging.Logger - client network.Client - fileManager files.FileManager - streamMerger command.StreamMerger - cancel context.CancelFunc + logger, pfLogger logging.Logger + client *http.Client + fileManager files.FileManager + streamMerger command.StreamMerger + cancel context.CancelFunc // Internal channels restart chan struct{} - portForwardSignals chan struct{} + portForwardSignals chan net.IP } func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, uid, gid int, allServers models.AllServers, - conf Configurator, fw firewall.Configurator, - logger logging.Logger, client network.Client, fileManager files.FileManager, + conf Configurator, fw firewall.Configurator, routing routing.Routing, + logger logging.Logger, client *http.Client, fileManager files.FileManager, streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { return &looper{ provider: provider, @@ -66,18 +68,20 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, allServers: allServers, conf: conf, fw: fw, + routing: routing, logger: logger.WithPrefix("openvpn: "), + pfLogger: logger.WithPrefix("port forwarding: "), client: client, fileManager: fileManager, streamMerger: streamMerger, cancel: cancel, restart: make(chan struct{}), - portForwardSignals: make(chan struct{}), + portForwardSignals: make(chan net.IP), } } -func (l *looper) Restart() { l.restart <- struct{}{} } -func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} } +func (l *looper) Restart() { l.restart <- struct{}{} } +func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } func (l *looper) GetSettings() (settings settings.OpenVPN) { l.settingsMutex.RLock() @@ -158,10 +162,12 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { go func(ctx context.Context) { for { select { + // TODO have a way to disable pf with a context case <-ctx.Done(): return - case <-l.portForwardSignals: - l.portForward(ctx, providerConf, l.client) + case gateway := <-l.portForwardSignals: + wg.Add(1) + go l.portForward(ctx, wg, providerConf, l.client, gateway) } } }(openvpnCtx) @@ -200,43 +206,25 @@ func (l *looper) logAndWait(ctx context.Context, err error) { <-ctx.Done() } -func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client network.Client) { +// portForward is a blocking operation which may or may not be infinite. +// You should therefore always call it in a goroutine +func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup, + providerConf provider.Provider, client *http.Client, gateway net.IP) { + defer wg.Done() settings := l.GetSettings() if !settings.Provider.PortForwarding.Enabled { return } - var port uint16 - err := fmt.Errorf("") - for err != nil { - if ctx.Err() != nil { - return - } - port, err = providerConf.GetPortForward(client) - if err != nil { - l.logAndWait(ctx, err) - } - } - - l.logger.Info("port forwarded is %d", port) - l.portForwardedMutex.Lock() - if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil { - l.logger.Error(err) - } - if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil { - l.logger.Error(err) - } - l.portForwarded = port - l.portForwardedMutex.Unlock() - - filepath := settings.Provider.PortForwarding.Filepath - l.logger.Info("writing forwarded port to %s", filepath) - err = l.fileManager.WriteLinesToFile( - string(filepath), []string{fmt.Sprintf("%d", port)}, - files.Ownership(l.uid, l.gid), files.Permissions(0400), - ) - if err != nil { - l.logger.Error(err) + syncState := func(port uint16) (pfFilepath models.Filepath) { + l.portForwardedMutex.Lock() + l.portForwarded = port + l.portForwardedMutex.Unlock() + settings := l.GetSettings() + return settings.Provider.PortForwarding.Filepath } + providerConf.PortForward(ctx, + client, l.fileManager, l.pfLogger, + gateway, l.fw, syncState) } func (l *looper) GetPortForwarded() (portForwarded uint16) { diff --git a/internal/provider/cyberghost.go b/internal/provider/cyberghost.go index b86425c8..7f2104f3 100644 --- a/internal/provider/cyberghost.go +++ b/internal/provider/cyberghost.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type cyberghost struct { @@ -135,6 +140,8 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity return lines } -func (c *cyberghost) GetPortForward(client network.Client) (port uint16, err error) { +func (c *cyberghost) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for cyberghost") } diff --git a/internal/provider/mullvad.go b/internal/provider/mullvad.go index 8f0a8fc4..b2f809f2 100644 --- a/internal/provider/mullvad.go +++ b/internal/provider/mullvad.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type mullvad struct { @@ -134,6 +139,8 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u return lines } -func (m *mullvad) GetPortForward(client network.Client) (port uint16, err error) { +func (m *mullvad) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for mullvad") } diff --git a/internal/provider/nordvpn.go b/internal/provider/nordvpn.go index 3f9146c0..f8652156 100644 --- a/internal/provider/nordvpn.go +++ b/internal/provider/nordvpn.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type nordvpn struct { @@ -142,6 +147,8 @@ func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u return lines } -func (n *nordvpn) GetPortForward(client network.Client) (port uint16, err error) { +func (n *nordvpn) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for nordvpn") } diff --git a/internal/provider/pia.go b/internal/provider/pia.go index 8486fa39..abbf2d24 100644 --- a/internal/provider/pia.go +++ b/internal/provider/pia.go @@ -1,35 +1,18 @@ package provider import ( - "encoding/hex" - "encoding/json" "fmt" - "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/crypto/random" - "github.com/qdm12/golibs/network" ) -type pia struct { - random random.Random - servers []models.PIAServer -} - -func newPrivateInternetAccess(servers []models.PIAServer) *pia { - return &pia{ - random: random.NewRandom(), - servers: servers, - } -} - -func (p *pia) filterServers(region string) (servers []models.PIAServer) { +func filterPIAServers(servers []models.PIAServer, region string) (filtered []models.PIAServer) { if len(region) == 0 { - return p.servers + return servers } - for _, server := range p.servers { + for _, server := range servers { if strings.EqualFold(server.Region, region) { return []models.PIAServer{server} } @@ -37,8 +20,8 @@ func (p *pia) filterServers(region string) (servers []models.PIAServer) { return nil } -func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - servers := p.filterServers(selection.Region) +func getPIAOpenVPNConnections(allServers []models.PIAServer, selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + servers := filterPIAServers(allServers, selection.Region) if len(servers) == 0 { return nil, fmt.Errorf("no server found for region %q", selection.Region) } @@ -87,7 +70,7 @@ func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connectio return connections, nil } -func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { var X509CRL, certificate string if extras.EncryptionPreset == constants.PIAEncryptionPresetNormal { if len(cipher) == 0 { @@ -161,28 +144,3 @@ func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, }...) return lines } - -func (p *pia) GetPortForward(client network.Client) (port uint16, err error) { - b, err := p.random.GenerateRandomBytes(32) - if err != nil { - return 0, err - } - clientID := hex.EncodeToString(b) - url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID) - content, status, err := client.GetContent(url) // TODO add ctx - switch { - case err != nil: - return 0, err - case status != http.StatusOK: - return 0, fmt.Errorf("status is %d for %s; does your PIA server support port forwarding?", status, url) - case len(content) == 0: - return 0, fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding") - } - body := struct { - Port uint16 `json:"port"` - }{} - if err := json.Unmarshal(content, &body); err != nil { - return 0, fmt.Errorf("port forwarding response: %w", err) - } - return body.Port, nil -} diff --git a/internal/provider/piav3.go b/internal/provider/piav3.go new file mode 100644 index 00000000..dec5f79c --- /dev/null +++ b/internal/provider/piav3.go @@ -0,0 +1,94 @@ +package provider + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/golibs/crypto/random" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" +) + +type piaV3 struct { + random random.Random + servers []models.PIAServer +} + +func newPrivateInternetAccessV3(servers []models.PIAServer) *piaV3 { + return &piaV3{ + random: random.NewRandom(), + servers: servers, + } +} + +func (p *piaV3) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + return getPIAOpenVPNConnections(p.servers, selection) +} + +func (p *piaV3) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { + return buildPIAConf(connections, verbosity, root, cipher, auth, extras) +} + +func (p *piaV3) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { + b, err := p.random.GenerateRandomBytes(32) + if err != nil { + pfLogger.Error(err) + return + } + clientID := hex.EncodeToString(b) + url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID) + response, err := client.Get(url) // TODO add ctx + if err != nil { + pfLogger.Error(err) + return + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + pfLogger.Error(fmt.Errorf("%s for %s; does your PIA server support port forwarding?", response.Status, url)) + return + } + b, err = ioutil.ReadAll(response.Body) + if err != nil { + pfLogger.Error(err) + return + } else if len(b) == 0 { + pfLogger.Error(fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding")) + return + } + body := struct { + Port uint16 `json:"port"` + }{} + if err := json.Unmarshal(b, &body); err != nil { + pfLogger.Error(fmt.Errorf("port forwarding response: %w", err)) + return + } + port := body.Port + + filepath := syncState(port) + pfLogger.Info("Writing port to %s", filepath) + if err := fileManager.WriteToFile( + string(filepath), []byte(fmt.Sprintf("%d", port)), + files.Permissions(0666), + ); err != nil { + pfLogger.Error(err) + } + + if err := fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil { + pfLogger.Error(err) + } + + <-ctx.Done() + if err := fw.RemoveAllowedPort(ctx, port); err != nil { + pfLogger.Error(err) + } +} diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go new file mode 100644 index 00000000..2936a958 --- /dev/null +++ b/internal/provider/piav4.go @@ -0,0 +1,412 @@ +package provider + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" + gluetunLog "github.com/qdm12/gluetun/internal/logging" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" +) + +type piaV4 struct { + servers []models.PIAServer + timeNow func() time.Time +} + +func newPrivateInternetAccessV4(servers []models.PIAServer) *piaV4 { + return &piaV4{ + servers: servers, + timeNow: time.Now, + } +} + +func (p *piaV4) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + return getPIAOpenVPNConnections(p.servers, selection) +} + +func (p *piaV4) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { + return buildPIAConf(connections, verbosity, root, cipher, auth, extras) +} + +//nolint:gocognit +func (p *piaV4) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { + if gateway == nil { + pfLogger.Error("aborting because: VPN gateway IP address was not found") + return + } + client, err := newPIAv4HTTPClient() + if err != nil { + pfLogger.Error("aborting because: %s", err) + return + } + defer pfLogger.Warn("loop exited") + data, err := readPIAPortForwardData(fileManager) + if err != nil { + pfLogger.Error(err) + } + dataFound := data.Port > 0 + durationToExpiration := data.Expiration.Sub(p.timeNow()) + expired := durationToExpiration <= 0 + + if dataFound { + pfLogger.Info("Found persistent forwarded port data for port %d", data.Port) + if expired { + pfLogger.Warn("Forwarded port data expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) + } else { + pfLogger.Info("Forwarded port data expires in %s", gluetunLog.FormatDuration(durationToExpiration)) + } + } + + if !dataFound || expired { + tryUntilSuccessful(ctx, pfLogger, func() error { + data, err = refreshPIAPortForwardData(client, gateway, fileManager) + return err + }) + if ctx.Err() != nil { + return + } + durationToExpiration = data.Expiration.Sub(p.timeNow()) + } + pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration)) + + // First time binding + tryUntilSuccessful(ctx, pfLogger, func() error { + return bindPIAPort(client, gateway, data) + }) + if ctx.Err() != nil { + return + } + + filepath := syncState(data.Port) + pfLogger.Info("Writing port to %s", filepath) + if err := fileManager.WriteToFile( + string(filepath), []byte(fmt.Sprintf("%d", data.Port)), + files.Permissions(0666), + ); err != nil { + pfLogger.Error(err) + } + + if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil { + pfLogger.Error(err) + } + + expiryTimer := time.NewTimer(durationToExpiration) + defer expiryTimer.Stop() + const keepAlivePeriod = 15 * time.Minute + keepAliveTicker := time.NewTicker(keepAlivePeriod) + defer keepAliveTicker.Stop() + + for { + select { + case <-ctx.Done(): + removeCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := fw.RemoveAllowedPort(removeCtx, data.Port); err != nil { + pfLogger.Error(err) + } + return + case <-keepAliveTicker.C: + if err := bindPIAPort(client, gateway, data); err != nil { + pfLogger.Error(err) + } + case <-expiryTimer.C: + pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) + oldPort := data.Port + for { + data, err = refreshPIAPortForwardData(client, gateway, fileManager) + if err != nil { + pfLogger.Error(err) + continue + } + break + } + durationToExpiration := data.Expiration.Sub(p.timeNow()) + pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration)) + if err := fw.RemoveAllowedPort(ctx, oldPort); err != nil { + pfLogger.Error(err) + } + if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil { + pfLogger.Error(err) + } + filepath := syncState(data.Port) + pfLogger.Info("Writing port to %s", filepath) + if err := fileManager.WriteToFile( + string(filepath), []byte(fmt.Sprintf("%d", data.Port)), + files.Permissions(0666), + ); err != nil { + pfLogger.Error(err) + } + if err := bindPIAPort(client, gateway, data); err != nil { + pfLogger.Error(err) + } + keepAliveTicker.Reset(keepAlivePeriod) + expiryTimer.Reset(durationToExpiration) + } + } +} + +func newPIAv4HTTPClient() (client *http.Client, err error) { + certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong) + if err != nil { + return nil, fmt.Errorf("cannot decode PIA root certificate: %w", err) + } + certificate, err := x509.ParseCertificate(certificateBytes) + if err != nil { + return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(certificate) + TLSClientConfig := &tls.Config{ + RootCAs: rootCAs, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: true, //nolint:gosec + } // TODO fix and remove InsecureSkipVerify + transport := http.Transport{ + TLSClientConfig: TLSClientConfig, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + const httpTimeout = 5 * time.Second + client = &http.Client{Transport: &transport, Timeout: httpTimeout} + return client, nil +} + +func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) { + data.Token, err = fetchPIAToken(fileManager, client) + if err != nil { + return data, fmt.Errorf("cannot obtain token: %w", err) + } + data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token) + if err != nil { + if strings.HasSuffix(err.Error(), "connection refused") { + return data, fmt.Errorf("cannot obtain port forwarding data: connection was refused, are you sure the region you are using supports port forwarding ;)") + } + return data, fmt.Errorf("cannot obtain port forwarding data: %w", err) + } + if err := writePIAPortForwardData(fileManager, data); err != nil { + return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err) + } + return data, nil +} + +type piaPayload struct { + Token string `json:"token"` + Port uint16 `json:"port"` + Expiration time.Time `json:"expires_at"` +} + +type piaPortForwardData struct { + Port uint16 `json:"port"` + Token string `json:"token"` + Signature string `json:"signature"` + Expiration time.Time `json:"expires_at"` +} + +func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) { + const filepath = string(constants.PIAPortForward) + exists, err := fileManager.FileExists(filepath) + if err != nil { + return data, err + } else if !exists { + return data, nil + } + b, err := fileManager.ReadFile(filepath) + if err != nil { + return data, err + } + if err := json.Unmarshal(b, &data); err != nil { + return data, err + } + return data, nil +} + +func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) { + b, err := json.Marshal(&data) + if err != nil { + return fmt.Errorf("cannot encode data: %w", err) + } + err = fileManager.WriteToFile(string(constants.PIAPortForward), b) + if err != nil { + return err + } + return nil +} + +func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) { + b, err := base64.RawStdEncoding.DecodeString(payload) + if err != nil { + return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err) + } + var payloadData piaPayload + if err := json.Unmarshal(b, &payloadData); err != nil { + return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err) + } + return payloadData.Port, payloadData.Token, payloadData.Expiration, nil +} + +func packPIAPayload(port uint16, token string, expiration time.Time) (payload string, err error) { + payloadData := piaPayload{ + Token: token, + Port: port, + Expiration: expiration, + } + b, err := json.Marshal(&payloadData) + if err != nil { + return "", fmt.Errorf("cannot serialize payload data: %w", err) + } + payload = base64.RawStdEncoding.EncodeToString(b) + return payload, nil +} + +func fetchPIAToken(fileManager files.FileManager, client *http.Client) (token string, err error) { + username, password, err := getOpenvpnCredentials(fileManager) + if err != nil { + return "", fmt.Errorf("cannot get Openvpn credentials: %w", err) + } + url := url.URL{ + Scheme: "https", + User: url.UserPassword(username, password), + Host: "10.0.0.1", + Path: "/authv3/generateToken", + } + request, err := http.NewRequest(http.MethodGet, url.String(), nil) + if err != nil { + return "", err + } + response, err := client.Do(request) + if err != nil { + return "", err + } + defer response.Body.Close() + b, err := ioutil.ReadAll(response.Body) + if response.StatusCode != http.StatusOK { + shortenMessage := string(b) + shortenMessage = strings.ReplaceAll(shortenMessage, "\n", "") + shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ") + return "", fmt.Errorf("%s: response received: %q", response.Status, shortenMessage) + } else if err != nil { + return "", err + } + var result struct { + Token string `json:"token"` + } + if err := json.Unmarshal(b, &result); err != nil { + return "", err + } else if len(result.Token) == 0 { + return "", fmt.Errorf("token is empty") + } + return result.Token, nil +} + +func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) { + authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf)) + if err != nil { + return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err) + } + lines := strings.Split(string(authData), "\n") + if len(lines) < 2 { + return "", "", fmt.Errorf("not enough lines (%d) in openvpn auth file", len(lines)) + } + username, password = lines[0], lines[1] + return username, password, nil +} + +func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) { + queryParams := url.Values{} + queryParams.Add("token", token) + url := url.URL{ + Scheme: "https", + Host: net.JoinHostPort(gateway.String(), "19999"), + Path: "/getSignature", + RawQuery: queryParams.Encode(), + } + response, err := client.Get(url.String()) + if err != nil { + return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return 0, "", expiration, fmt.Errorf("cannot obtain signature: %s", response.Status) + } + b, err := ioutil.ReadAll(response.Body) + if err != nil { + return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) + } + var data struct { + Status string `json:"status"` + Payload string `json:"payload"` + Signature string `json:"signature"` + } + if err := json.Unmarshal(b, &data); err != nil { + return 0, "", expiration, fmt.Errorf("cannot decode received data: %w", err) + } else if data.Status != "OK" { + return 0, "", expiration, fmt.Errorf("response received from PIA has status %s", data.Status) + } + + port, _, expiration, err = unpackPIAPayload(data.Payload) + return port, data.Signature, expiration, err +} + +func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) { + payload, err := packPIAPayload(data.Port, data.Token, data.Expiration) + if err != nil { + return err + } + queryParams := url.Values{} + queryParams.Add("payload", payload) + queryParams.Add("signature", data.Signature) + url := url.URL{ + Scheme: "https", + Host: net.JoinHostPort(gateway.String(), "19999"), + Path: "/bindPort", + RawQuery: queryParams.Encode(), + } + + response, err := client.Get(url.String()) + if err != nil { + return fmt.Errorf("cannot bind port: %w", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("cannot bind port: %s", response.Status) + } + b, err := ioutil.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("cannot bind port: %w", err) + } + var responseData struct { + Status string `json:"status"` + Message string `json:"message"` + } + if err := json.Unmarshal(b, &responseData); err != nil { + return fmt.Errorf("cannot bind port: %w", err) + } else if responseData.Status != "OK" { + return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message) + } + return nil +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 1ce8fe1b..033a36b9 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1,24 +1,32 @@ package provider import ( + "context" + "net" + "net/http" + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) // Provider contains methods to read and modify the openvpn configuration to connect as a client type Provider interface { GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) - GetPortForward(client network.Client) (port uint16, err error) + PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) } func New(provider models.VPNProvider, allServers models.AllServers) Provider { switch provider { case constants.PrivateInternetAccess: - return newPrivateInternetAccess(allServers.Pia.Servers) + return newPrivateInternetAccessV4(allServers.Pia.Servers) case constants.PrivateInternetAccessOld: - return newPrivateInternetAccess(allServers.PiaOld.Servers) + return newPrivateInternetAccessV3(allServers.PiaOld.Servers) case constants.Mullvad: return newMullvad(allServers.Mullvad.Servers) case constants.Windscribe: diff --git a/internal/provider/purevpn.go b/internal/provider/purevpn.go index 75a68793..73389177 100644 --- a/internal/provider/purevpn.go +++ b/internal/provider/purevpn.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type purevpn struct { @@ -157,6 +162,8 @@ func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u return lines } -func (p *purevpn) GetPortForward(client network.Client) (port uint16, err error) { +func (p *purevpn) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for purevpn") } diff --git a/internal/provider/surfshark.go b/internal/provider/surfshark.go index c1841335..8b61181d 100644 --- a/internal/provider/surfshark.go +++ b/internal/provider/surfshark.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type surfshark struct { @@ -135,6 +140,8 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity, return lines } -func (s *surfshark) GetPortForward(client network.Client) (port uint16, err error) { +func (s *surfshark) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for surfshark") } diff --git a/internal/provider/utils.go b/internal/provider/utils.go new file mode 100644 index 00000000..238d8add --- /dev/null +++ b/internal/provider/utils.go @@ -0,0 +1,29 @@ +package provider + +import ( + "context" + "time" + + "github.com/qdm12/golibs/logging" +) + +func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) { + const retryPeriod = 10 * time.Second + for { + err := fn() + if err == nil { + break + } + logger.Error(err) + logger.Info("Trying again in %s", retryPeriod) + timer := time.NewTimer(retryPeriod) + select { + case <-timer.C: + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} diff --git a/internal/provider/vyprvpn.go b/internal/provider/vyprvpn.go index 193e7cc2..2d54598c 100644 --- a/internal/provider/vyprvpn.go +++ b/internal/provider/vyprvpn.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type vyprvpn struct { @@ -121,6 +126,8 @@ func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u return lines } -func (v *vyprvpn) GetPortForward(client network.Client) (port uint16, err error) { +func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for vyprvpn") } diff --git a/internal/provider/windscribe.go b/internal/provider/windscribe.go index fca0dcea..95142582 100644 --- a/internal/provider/windscribe.go +++ b/internal/provider/windscribe.go @@ -1,12 +1,17 @@ package provider import ( + "context" "fmt" + "net" + "net/http" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/network" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" ) type windscribe struct { @@ -133,6 +138,8 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity return lines } -func (w *windscribe) GetPortForward(client network.Client) (port uint16, err error) { +func (w *windscribe) PortForward(ctx context.Context, client *http.Client, + fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for windscribe") } diff --git a/internal/routing/reader.go b/internal/routing/reader.go index 08515f6f..2caf239c 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/golibs/files" ) func parseRoutingTable(data []byte) (entries []routingEntry, err error) { @@ -23,12 +24,16 @@ func parseRoutingTable(data []byte) (entries []routingEntry, err error) { return entries, nil } -func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { - data, err := r.fileManager.ReadFile(string(constants.NetRoute)) +func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) { + data, err := fileManager.ReadFile(string(constants.NetRoute)) if err != nil { - return "", nil, err + return nil, err } - entries, err := parseRoutingTable(data) + return parseRoutingTable(data) +} + +func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { + entries, err := getRoutingEntries(r.fileManager) if err != nil { return "", nil, err } @@ -52,11 +57,7 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP } func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { - data, err := r.fileManager.ReadFile(string(constants.NetRoute)) - if err != nil { - return defaultSubnet, err - } - entries, err := parseRoutingTable(data) + entries, err := getRoutingEntries(r.fileManager) if err != nil { return defaultSubnet, err } @@ -79,11 +80,7 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { } func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) { - data, err := r.fileManager.ReadFile(string(constants.NetRoute)) - if err != nil { - return false, fmt.Errorf("cannot check route existence: %w", err) - } - entries, err := parseRoutingTable(data) + entries, err := getRoutingEntries(r.fileManager) if err != nil { return false, fmt.Errorf("cannot check route existence: %w", err) } @@ -96,12 +93,8 @@ func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) { return false, nil } -func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) { - data, err := r.fileManager.ReadFile(string(constants.NetRoute)) - if err != nil { - return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err) - } - entries, err := parseRoutingTable(data) +func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) { + entries, err := getRoutingEntries(r.fileManager) if err != nil { return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err) } @@ -115,6 +108,20 @@ func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) { return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes") } +func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) { + entries, err := getRoutingEntries(r.fileManager) + if err != nil { + return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err) + } + for _, entry := range entries { + if entry.iface == string(constants.TUN) && + entry.destination.Equal(net.IP{0, 0, 0, 0}) { + return entry.gateway, nil + } + } + return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes") +} + func ipIsPrivate(ip net.IP) bool { if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { return true diff --git a/internal/routing/reader_test.go b/internal/routing/reader_test.go index 82306b6c..c90d11ea 100644 --- a/internal/routing/reader_test.go +++ b/internal/routing/reader_test.go @@ -291,7 +291,7 @@ eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF } } -func Test_VPNGatewayIP(t *testing.T) { +func Test_VPNDestinationIP(t *testing.T) { t.Parallel() tests := map[string]struct { defaultInterface string @@ -334,7 +334,7 @@ eth0 x filemanager.EXPECT().ReadFile(string(constants.NetRoute)). Return(tc.data, tc.readErr).Times(1) r := &routing{fileManager: filemanager} - ip, err := r.VPNGatewayIP(tc.defaultInterface) + ip, err := r.VPNDestinationIP(tc.defaultInterface) if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) diff --git a/internal/routing/routing.go b/internal/routing/routing.go index 34496eb4..4cc57c2b 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -14,7 +14,8 @@ type Routing interface { DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) LocalSubnet() (defaultSubnet net.IPNet, err error) - VPNGatewayIP(defaultInterface string) (ip net.IP, err error) + VPNDestinationIP(defaultInterface string) (ip net.IP, err error) + VPNLocalGatewayIP() (ip net.IP, err error) SetDebug() } diff --git a/internal/settings/providers.go b/internal/settings/providers.go index 94240d8d..2299d155 100644 --- a/internal/settings/providers.go +++ b/internal/settings/providers.go @@ -10,7 +10,16 @@ import ( // GetPIASettings obtains PIA settings from environment variables using the params package. func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) { - settings.Name = constants.PrivateInternetAccess + return getPIASettings(paramsReader, constants.PrivateInternetAccess) +} + +// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package. +func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) { + return getPIASettings(paramsReader, constants.PrivateInternetAccessOld) +} + +func getPIASettings(paramsReader params.Reader, name models.VPNProvider) (settings models.ProviderSettings, err error) { + settings.Name = name settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol() if err != nil { return settings, err @@ -29,30 +38,6 @@ func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSetting if err != nil { return settings, err } - return settings, nil -} - -// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package. -func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) { - settings.Name = constants.PrivateInternetAccessOld - settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol() - if err != nil { - return settings, err - } - settings.ServerSelection.TargetIP, err = paramsReader.GetTargetIP() - if err != nil { - return settings, err - } - encryptionPreset, err := paramsReader.GetPIAEncryptionPreset() - if err != nil { - return settings, err - } - settings.ServerSelection.EncryptionPreset = encryptionPreset - settings.ExtraConfigOptions.EncryptionPreset = encryptionPreset - settings.ServerSelection.Region, err = paramsReader.GetPIAOldRegion() - if err != nil { - return settings, err - } settings.PortForwarding.Enabled, err = paramsReader.GetPortForwarding() if err != nil { return settings, err diff --git a/internal/version/version.go b/internal/version/version.go index dea6c9da..3d6e3c55 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "time" + + "github.com/qdm12/gluetun/internal/logging" ) // GetMessage returns a message for the user describing if there is a newer version @@ -30,35 +32,12 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin if tagName == version { return fmt.Sprintf("You are running the latest release %s", version), nil } - timeSinceRelease := formatDuration(time.Since(releaseTime)) + timeSinceRelease := logging.FormatDuration(time.Since(releaseTime)) return fmt.Sprintf("There is a new release %s (%s) created %s ago", tagName, name, timeSinceRelease), nil } -func formatDuration(duration time.Duration) string { - switch { - case duration < time.Minute: - seconds := int(duration.Round(time.Second).Seconds()) - if seconds < 2 { - return fmt.Sprintf("%d second", seconds) - } - return fmt.Sprintf("%d seconds", seconds) - case duration <= time.Hour: - minutes := int(duration.Round(time.Minute).Minutes()) - if minutes == 1 { - return "1 minute" - } - return fmt.Sprintf("%d minutes", minutes) - case duration < 48*time.Hour: - hours := int(duration.Truncate(time.Hour).Hours()) - return fmt.Sprintf("%d hours", hours) - default: - days := int(duration.Truncate(time.Hour).Hours() / 24) - return fmt.Sprintf("%d days", days) - } -} - func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) { releases, err := getGithubReleases(client) if err != nil {