diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 327ccab0..5c801bb9 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -328,9 +329,9 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger.Error(line) } switch { - case line == "openvpn: Initialization Sequence Completed": + case strings.Contains(line, "Initialization Sequence Completed"): signalTunnelReady() - case line == "openvpn: TLS Error: TLS key negotiation failed to occur within 60 seconds (check your network connectivity)": + case strings.Contains(line, "TLS Error: TLS key negotiation failed to occur within 60 seconds (check your network connectivity)"): logger.Warn("This means that either...") logger.Warn("1. The VPN server IP address you are trying to connect to is no longer valid, see https://github.com/qdm12/gluetun/wiki/Update-servers-information") logger.Warn("2. The VPN server crashed, try changing region") diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 2896b2ab..8e3bd3ef 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -70,13 +70,13 @@ func OpenvpnConfig() error { if err != nil { return err } - providerConf := provider.New(allSettings.OpenVPN.Provider.Name, allServers) - connections, err := providerConf.GetOpenVPNConnections(allSettings.OpenVPN.Provider.ServerSelection) + providerConf := provider.New(allSettings.OpenVPN.Provider.Name, allServers, time.Now) + connection, err := providerConf.GetOpenVPNConnection(allSettings.OpenVPN.Provider.ServerSelection) if err != nil { return err } lines := providerConf.BuildConf( - connections, + connection, allSettings.OpenVPN.Verbosity, allSettings.System.UID, allSettings.System.GID, diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index a2251f92..0650c513 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -85,8 +85,8 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - for _, conn := range c.vpnConnections { - if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, conn, remove); err != nil { + if c.vpnConnection.IP != nil { + if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index f4d4d07a..6c33ca69 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -16,7 +16,7 @@ import ( type Configurator interface { Version(ctx context.Context) (string, error) SetEnabled(ctx context.Context, enabled bool) (err error) - SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) + SetVPNConnection(ctx context.Context, connection models.OpenVPNConnection) (err error) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error) @@ -39,7 +39,7 @@ type configurator struct { //nolint:maligned // State enabled bool - vpnConnections []models.OpenVPNConnection + vpnConnection models.OpenVPNConnection allowedSubnets []net.IPNet allowedInputPorts map[uint16]string // port to interface mapping stateMutex sync.Mutex diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go index 5e3924fa..bac439f0 100644 --- a/internal/firewall/vpn.go +++ b/internal/firewall/vpn.go @@ -7,95 +7,33 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) { +func (c *configurator) SetVPNConnection(ctx context.Context, connection models.OpenVPNConnection) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() if !c.enabled { - c.logger.Info("firewall disabled, only updating VPN connections internal list") - c.vpnConnections = make([]models.OpenVPNConnection, len(connections)) - copy(c.vpnConnections, connections) + c.logger.Info("firewall disabled, only updating internal VPN connection") + c.vpnConnection = connection return nil } - c.logger.Info("setting VPN connections through firewall...") + c.logger.Info("setting VPN connection through firewall...") - connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections) - connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections) - if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 { + if c.vpnConnection.Equal(connection) { return nil } - c.removeConnections(ctx, connectionsToRemove, c.defaultInterface) - if err := c.addConnections(ctx, connectionsToAdd, c.defaultInterface); err != nil { - return fmt.Errorf("cannot set VPN connections through firewall: %w", err) - } - - return nil -} - -func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection { - L := len(connections) - for i := range connections { - if connection.Equal(connections[i]) { - connections[i] = connections[L-1] - connections = connections[:L-1] - break - } - } - return connections -} - -func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) { - for _, newConnection := range newConnections { - found := false - for _, oldConnection := range oldConnections { - if oldConnection.Equal(newConnection) { - found = true - break - } - } - if !found { - connectionsToAdd = append(connectionsToAdd, newConnection) - } - } - return connectionsToAdd -} - -func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) { - for _, oldConnection := range oldConnections { - found := false - for _, newConnection := range newConnections { - if oldConnection.Equal(newConnection) { - found = true - break - } - } - if !found { - connectionsToRemove = append(connectionsToRemove, oldConnection) - } - } - return connectionsToRemove -} - -func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) { - for _, conn := range connections { - const remove = true - if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil { + remove := true + if c.vpnConnection.IP != nil { + if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err) - continue } - c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn) } -} - -func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error { - const remove = false - for _, conn := range connections { - if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil { - return err - } - c.vpnConnections = append(c.vpnConnections, conn) + c.vpnConnection = models.OpenVPNConnection{} + remove = false + if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil { + return fmt.Errorf("cannot set VPN connection through firewall: %w", err) } + c.vpnConnection = connection return nil } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 22839759..f3be3564 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -113,16 +113,16 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { for ctx.Err() == nil { settings := l.GetSettings() l.allServersMutex.RLock() - providerConf := provider.New(l.provider, l.allServers) + providerConf := provider.New(l.provider, l.allServers, time.Now) l.allServersMutex.RUnlock() - connections, err := providerConf.GetOpenVPNConnections(settings.Provider.ServerSelection) + connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection) if err != nil { l.logger.Error(err) l.cancel() return } lines := providerConf.BuildConf( - connections, + connection, settings.Verbosity, l.uid, l.gid, @@ -143,7 +143,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { return } - if err := l.fw.SetVPNConnections(ctx, connections); err != nil { + if err := l.fw.SetVPNConnection(ctx, connection); err != nil { l.logger.Error(err) l.cancel() return diff --git a/internal/provider/cyberghost.go b/internal/provider/cyberghost.go index 7f2104f3..c069f5bc 100644 --- a/internal/provider/cyberghost.go +++ b/internal/provider/cyberghost.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type cyberghost struct { - servers []models.CyberghostServer + servers []models.CyberghostServer + randSource rand.Source } -func newCyberghost(servers []models.CyberghostServer) *cyberghost { +func newCyberghost(servers []models.CyberghostServer, timeNow timeNowFunc) *cyberghost { return &cyberghost{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -39,17 +42,18 @@ func (c *cyberghost) filterServers(region, group string) (servers []models.Cyber return servers } -func (c *cyberghost) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { +func (c *cyberghost) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { servers := c.filterServers(selection.Region, selection.Group) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q and group %q", selection.Region, selection.Group) + return connection, fmt.Errorf("no server found for region %q and group %q", selection.Region, selection.Group) } + var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { if selection.TargetIP != nil { if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: 443, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: IP, Port: 443, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: IP, Port: 443, Protocol: selection.Protocol}) @@ -58,17 +62,13 @@ func (c *cyberghost) GetOpenVPNConnections(selection models.ServerSelection) (co } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, c.randSource), nil } -func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { if len(cipher) == 0 { cipher = aes256cbc } @@ -102,7 +102,8 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", connections[0].Protocol), + fmt.Sprintf("proto %s", connection.Protocol), + fmt.Sprintf("remote %s %d", connection.IP, connection.Port), fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("auth %s", auth), } @@ -112,9 +113,6 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP, connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----", diff --git a/internal/provider/mullvad.go b/internal/provider/mullvad.go index b2f809f2..5aa48a59 100644 --- a/internal/provider/mullvad.go +++ b/internal/provider/mullvad.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type mullvad struct { - servers []models.MullvadServer + servers []models.MullvadServer + randSource rand.Source } -func newMullvad(servers []models.MullvadServer) *mullvad { +func newMullvad(servers []models.MullvadServer, timeNow timeNowFunc) *mullvad { return &mullvad{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -44,10 +47,10 @@ func (m *mullvad) filterServers(country, city, isp string) (servers []models.Mul return servers } -func (m *mullvad) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { +func (m *mullvad) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { servers := m.filterServers(selection.Country, selection.City, selection.ISP) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for country %q, city %q and ISP %q", selection.Country, selection.City, selection.ISP) + return connection, fmt.Errorf("no server found for country %q, city %q and ISP %q", selection.Country, selection.City, selection.ISP) } var defaultPort uint16 = 1194 @@ -55,6 +58,7 @@ func (m *mullvad) GetOpenVPNConnections(selection models.ServerSelection) (conne defaultPort = 443 } + var connections []models.OpenVPNConnection for _, server := range servers { port := defaultPort if selection.CustomPort > 0 { @@ -63,7 +67,7 @@ func (m *mullvad) GetOpenVPNConnections(selection models.ServerSelection) (conne for _, IP := range server.IPs { if selection.TargetIP != nil { if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) @@ -72,17 +76,13 @@ func (m *mullvad) GetOpenVPNConnections(selection models.ServerSelection) (conne } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, m.randSource), nil } -func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func (m *mullvad) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { if len(cipher) == 0 { cipher = aes256cbc } @@ -113,7 +113,8 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", connections[0].Protocol), + fmt.Sprintf("proto %s", connection.Protocol), + fmt.Sprintf("remote %s %d", connection.IP, connection.Port), fmt.Sprintf("cipher %s", cipher), } if extras.OpenVPNIPv6 { @@ -125,9 +126,6 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP, connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----", diff --git a/internal/provider/nordvpn.go b/internal/provider/nordvpn.go index f8652156..b44a64c3 100644 --- a/internal/provider/nordvpn.go +++ b/internal/provider/nordvpn.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type nordvpn struct { - servers []models.NordvpnServer + servers []models.NordvpnServer + randSource rand.Source } -func newNordvpn(servers []models.NordvpnServer) *nordvpn { +func newNordvpn(servers []models.NordvpnServer, timeNow timeNowFunc) *nordvpn { return &nordvpn{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -45,10 +48,10 @@ func (n *nordvpn) filterServers(region string, protocol models.NetworkProtocol, return servers } -func (n *nordvpn) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { //nolint:dupl +func (n *nordvpn) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { //nolint:dupl servers := n.filterServers(selection.Region, selection.Protocol, selection.Number) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q, protocol %s and number %d", selection.Region, selection.Protocol, selection.Number) + return connection, fmt.Errorf("no server found for region %q, protocol %s and number %d", selection.Region, selection.Protocol, selection.Number) } var port uint16 @@ -58,13 +61,14 @@ func (n *nordvpn) GetOpenVPNConnections(selection models.ServerSelection) (conne case selection.Protocol == constants.TCP: port = 443 default: - return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) + return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) } + var connections []models.OpenVPNConnection for _, server := range servers { if selection.TargetIP != nil { if selection.TargetIP.Equal(server.IP) { - return []models.OpenVPNConnection{{IP: server.IP, Port: port, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: server.IP, Port: port, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: server.IP, Port: port, Protocol: selection.Protocol}) @@ -72,17 +76,13 @@ func (n *nordvpn) GetOpenVPNConnections(selection models.ServerSelection) (conne } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, n.randSource), nil } -func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl +func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl if len(cipher) == 0 { cipher = aes256cbc } @@ -119,16 +119,14 @@ func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", string(connections[0].Protocol)), + fmt.Sprintf("proto %s", string(connection.Protocol)), + fmt.Sprintf("remote %s %d", connection.IP.String(), connection.Port), fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("auth %s", auth), } if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP.String(), connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----", diff --git a/internal/provider/pia.go b/internal/provider/pia.go index 587218b9..0fca64e0 100644 --- a/internal/provider/pia.go +++ b/internal/provider/pia.go @@ -8,7 +8,7 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func buildPIAConf(connection models.OpenVPNConnection, verbosity int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { var X509CRL, certificate string if extras.EncryptionPreset == constants.PIAEncryptionPresetNormal { if len(cipher) == 0 { @@ -52,7 +52,8 @@ func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bo // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", connections[0].Protocol), + fmt.Sprintf("proto %s", connection.Protocol), + fmt.Sprintf("remote %s %d", connection.IP, connection.Port), fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("auth %s", auth), } @@ -62,9 +63,6 @@ func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bo if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP, connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN X509 CRL-----", diff --git a/internal/provider/piav3.go b/internal/provider/piav3.go index 84a0284e..dff93681 100644 --- a/internal/provider/piav3.go +++ b/internal/provider/piav3.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "math/rand" "net" "net/http" "strings" @@ -13,38 +14,84 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/crypto/random" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" ) type piaV3 struct { - random random.Random - servers []models.PIAOldServer + servers []models.PIAOldServer + randSource rand.Source } -func newPrivateInternetAccessV3(servers []models.PIAOldServer) *piaV3 { +func newPrivateInternetAccessV3(servers []models.PIAOldServer, timeNow timeNowFunc) *piaV3 { return &piaV3{ - random: random.NewRandom(), - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } -func (p *piaV3) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - return getPIAOldOpenVPNConnections(p.servers, selection) +func (p *piaV3) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { + servers := filterPIAOldServers(p.servers, selection.Region) + if len(servers) == 0 { + return connection, fmt.Errorf("no server found for region %q", selection.Region) + } + + var port uint16 + switch selection.Protocol { + case constants.TCP: + switch selection.EncryptionPreset { + case constants.PIAEncryptionPresetNormal: + port = 502 + case constants.PIAEncryptionPresetStrong: + port = 501 + } + case constants.UDP: + switch selection.EncryptionPreset { + case constants.PIAEncryptionPresetNormal: + port = 1198 + case constants.PIAEncryptionPresetStrong: + port = 1197 + } + } + if port == 0 { + return connection, fmt.Errorf("combination of protocol %q and encryption %q does not yield any port number", selection.Protocol, selection.EncryptionPreset) + } + + var connections []models.OpenVPNConnection + for _, server := range servers { + for _, IP := range server.IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + } + } + } + + if selection.TargetIP != nil { + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + return pickRandomConnection(connections, p.randSource), nil } -func (p *piaV3) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { - return buildPIAConf(connections, verbosity, root, cipher, auth, extras) +func (p *piaV3) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { + return buildPIAConf(connection, verbosity, root, cipher, auth, extras) } func (p *piaV3) PortForward(ctx context.Context, client *http.Client, fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { - b, err := p.random.GenerateRandomBytes(32) + b := make([]byte, 32) + n, err := rand.New(p.randSource).Read(b) //nolint:gosec if err != nil { pfLogger.Error(err) return + } else if n != 32 { + pfLogger.Error("only read %d bytes instead of 32", n) + return } clientID := hex.EncodeToString(b) url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID) @@ -105,53 +152,3 @@ func filterPIAOldServers(servers []models.PIAOldServer, region string) (filtered } return nil } - -func getPIAOldOpenVPNConnections(allServers []models.PIAOldServer, selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - servers := filterPIAOldServers(allServers, selection.Region) - if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q", selection.Region) - } - - var port uint16 - switch selection.Protocol { - case constants.TCP: - switch selection.EncryptionPreset { - case constants.PIAEncryptionPresetNormal: - port = 502 - case constants.PIAEncryptionPresetStrong: - port = 501 - } - case constants.UDP: - switch selection.EncryptionPreset { - case constants.PIAEncryptionPresetNormal: - port = 1198 - case constants.PIAEncryptionPresetStrong: - port = 1197 - } - } - if port == 0 { - return nil, fmt.Errorf("combination of protocol %q and encryption %q does not yield any port number", selection.Protocol, selection.EncryptionPreset) - } - - for _, server := range servers { - for _, IP := range server.IPs { - if selection.TargetIP != nil { - if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil - } - } else { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) - } - } - } - - if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) - } - - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil -} diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go index 1adbdbf0..429158d8 100644 --- a/internal/provider/piav4.go +++ b/internal/provider/piav4.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "math/rand" "net" "net/http" "net/url" @@ -23,23 +24,72 @@ import ( ) type piaV4 struct { - servers []models.PIAServer - timeNow func() time.Time + servers []models.PIAServer + timeNow timeNowFunc + randSource rand.Source } -func newPrivateInternetAccessV4(servers []models.PIAServer) *piaV4 { +func newPrivateInternetAccessV4(servers []models.PIAServer, timeNow timeNowFunc) *piaV4 { return &piaV4{ - servers: servers, - timeNow: time.Now, + servers: servers, + timeNow: timeNow, + randSource: rand.NewSource(timeNow().UnixNano()), } } -func (p *piaV4) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - return getPIAOpenVPNConnections(p.servers, selection) +func (p *piaV4) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { + servers := filterPIAServers(p.servers, selection.Region) + if len(servers) == 0 { + return connection, fmt.Errorf("no server found for region %q", selection.Region) + } + + var port uint16 + switch selection.Protocol { + case constants.TCP: + switch selection.EncryptionPreset { + case constants.PIAEncryptionPresetNormal: + port = 502 + case constants.PIAEncryptionPresetStrong: + port = 501 + } + case constants.UDP: + switch selection.EncryptionPreset { + case constants.PIAEncryptionPresetNormal: + port = 1198 + case constants.PIAEncryptionPresetStrong: + port = 1197 + } + } + if port == 0 { + return connection, fmt.Errorf("combination of protocol %q and encryption %q does not yield any port number", selection.Protocol, selection.EncryptionPreset) + } + + var connections []models.OpenVPNConnection + for _, server := range servers { + IPs := server.OpenvpnUDP.IPs + if selection.Protocol == constants.TCP { + IPs = server.OpenvpnTCP.IPs + } + for _, IP := range IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + } + } + } + + if selection.TargetIP != nil { + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + return pickRandomConnection(connections, p.randSource), nil } -func (p *piaV4) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { - return buildPIAConf(connections, verbosity, root, cipher, auth, extras) +func (p *piaV4) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { + return buildPIAConf(connection, verbosity, root, cipher, auth, extras) } //nolint:gocognit @@ -173,59 +223,6 @@ func filterPIAServers(servers []models.PIAServer, region string) (filtered []mod return nil } -func getPIAOpenVPNConnections(allServers []models.PIAServer, selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - servers := filterPIAServers(allServers, selection.Region) - if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q", selection.Region) - } - - var port uint16 - switch selection.Protocol { - case constants.TCP: - switch selection.EncryptionPreset { - case constants.PIAEncryptionPresetNormal: - port = 502 - case constants.PIAEncryptionPresetStrong: - port = 501 - } - case constants.UDP: - switch selection.EncryptionPreset { - case constants.PIAEncryptionPresetNormal: - port = 1198 - case constants.PIAEncryptionPresetStrong: - port = 1197 - } - } - if port == 0 { - return nil, fmt.Errorf("combination of protocol %q and encryption %q does not yield any port number", selection.Protocol, selection.EncryptionPreset) - } - for _, server := range servers { - IPs := server.OpenvpnUDP.IPs - if selection.Protocol == constants.TCP { - IPs = server.OpenvpnTCP.IPs - } - for _, IP := range IPs { - if selection.TargetIP != nil { - if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil - } - } else { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) - } - } - } - - if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) - } - - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil -} - func newPIAv4HTTPClient() (client *http.Client, err error) { certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong) if err != nil { diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 033a36b9..c3f6278d 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -14,33 +14,33 @@ import ( // Provider contains methods to read and modify the openvpn configuration to connect as a client type Provider interface { - GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) - BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) + GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) + BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) PortForward(ctx context.Context, client *http.Client, fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) } -func New(provider models.VPNProvider, allServers models.AllServers) Provider { +func New(provider models.VPNProvider, allServers models.AllServers, timeNow timeNowFunc) Provider { switch provider { case constants.PrivateInternetAccess: - return newPrivateInternetAccessV4(allServers.Pia.Servers) + return newPrivateInternetAccessV4(allServers.Pia.Servers, timeNow) case constants.PrivateInternetAccessOld: - return newPrivateInternetAccessV3(allServers.PiaOld.Servers) + return newPrivateInternetAccessV3(allServers.PiaOld.Servers, timeNow) case constants.Mullvad: - return newMullvad(allServers.Mullvad.Servers) + return newMullvad(allServers.Mullvad.Servers, timeNow) case constants.Windscribe: - return newWindscribe(allServers.Windscribe.Servers) + return newWindscribe(allServers.Windscribe.Servers, timeNow) case constants.Surfshark: - return newSurfshark(allServers.Surfshark.Servers) + return newSurfshark(allServers.Surfshark.Servers, timeNow) case constants.Cyberghost: - return newCyberghost(allServers.Cyberghost.Servers) + return newCyberghost(allServers.Cyberghost.Servers, timeNow) case constants.Vyprvpn: - return newVyprvpn(allServers.Vyprvpn.Servers) + return newVyprvpn(allServers.Vyprvpn.Servers, timeNow) case constants.Nordvpn: - return newNordvpn(allServers.Nordvpn.Servers) + return newNordvpn(allServers.Nordvpn.Servers, timeNow) case constants.Purevpn: - return newPurevpn(allServers.Purevpn.Servers) + return newPurevpn(allServers.Purevpn.Servers, timeNow) default: return nil // should never occur } diff --git a/internal/provider/purevpn.go b/internal/provider/purevpn.go index 73389177..544c0ef2 100644 --- a/internal/provider/purevpn.go +++ b/internal/provider/purevpn.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type purevpn struct { - servers []models.PurevpnServer + servers []models.PurevpnServer + randSource rand.Source } -func newPurevpn(servers []models.PurevpnServer) *purevpn { +func newPurevpn(servers []models.PurevpnServer, timeNow timeNowFunc) *purevpn { return &purevpn{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -44,10 +47,10 @@ func (p *purevpn) filterServers(region, country, city string) (servers []models. return servers } -func (p *purevpn) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { //nolint:dupl +func (p *purevpn) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { //nolint:dupl servers := p.filterServers(selection.Region, selection.Country, selection.City) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q, country %q and city %q", selection.Region, selection.Country, selection.City) + return connection, fmt.Errorf("no server found for region %q, country %q and city %q", selection.Region, selection.Country, selection.City) } var port uint16 @@ -57,14 +60,15 @@ func (p *purevpn) GetOpenVPNConnections(selection models.ServerSelection) (conne case selection.Protocol == constants.TCP: port = 80 default: - return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) + return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) } + var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { if selection.TargetIP != nil { if IP.Equal(selection.TargetIP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) @@ -73,17 +77,13 @@ func (p *purevpn) GetOpenVPNConnections(selection models.ServerSelection) (conne } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, p.randSource), nil } -func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl +func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl if len(cipher) == 0 { cipher = aes256cbc } @@ -114,15 +114,13 @@ func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", string(connections[0].Protocol)), + fmt.Sprintf("proto %s", string(connection.Protocol)), + fmt.Sprintf("remote %s %d", connection.IP.String(), connection.Port), fmt.Sprintf("cipher %s", cipher), } if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP.String(), connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----", @@ -156,7 +154,7 @@ func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u if len(auth) > 0 { lines = append(lines, "auth "+auth) } - if connections[0].Protocol == constants.UDP { + if connection.Protocol == constants.UDP { lines = append(lines, "explicit-exit-notify") } return lines diff --git a/internal/provider/surfshark.go b/internal/provider/surfshark.go index 8b61181d..072f767f 100644 --- a/internal/provider/surfshark.go +++ b/internal/provider/surfshark.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type surfshark struct { - servers []models.SurfsharkServer + servers []models.SurfsharkServer + randSource rand.Source } -func newSurfshark(servers []models.SurfsharkServer) *surfshark { +func newSurfshark(servers []models.SurfsharkServer, timeNow timeNowFunc) *surfshark { return &surfshark{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -36,10 +39,10 @@ func (s *surfshark) filterServers(region string) (servers []models.SurfsharkServ return nil } -func (s *surfshark) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { //nolint:dupl +func (s *surfshark) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { //nolint:dupl servers := s.filterServers(selection.Region) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q", selection.Region) + return connection, fmt.Errorf("no server found for region %q", selection.Region) } var port uint16 @@ -49,14 +52,15 @@ func (s *surfshark) GetOpenVPNConnections(selection models.ServerSelection) (con case selection.Protocol == constants.UDP: port = 1194 default: - return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) + return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) } + var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { if selection.TargetIP != nil { if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) @@ -65,17 +69,13 @@ func (s *surfshark) GetOpenVPNConnections(selection models.ServerSelection) (con } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, s.randSource), nil } -func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl +func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl if len(cipher) == 0 { cipher = aes256cbc } @@ -112,16 +112,14 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity, // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", connections[0].Protocol), + fmt.Sprintf("proto %s", connection.Protocol), + fmt.Sprintf("remote %s %d", connection.IP, connection.Port), fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("auth %s", auth), } if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP, connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----", diff --git a/internal/provider/utils.go b/internal/provider/utils.go index 238d8add..67a06447 100644 --- a/internal/provider/utils.go +++ b/internal/provider/utils.go @@ -2,11 +2,15 @@ package provider import ( "context" + "math/rand" "time" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/logging" ) +type timeNowFunc func() time.Time + func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) { const retryPeriod = 10 * time.Second for { @@ -27,3 +31,7 @@ func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() er } } } + +func pickRandomConnection(connections []models.OpenVPNConnection, source rand.Source) models.OpenVPNConnection { + return connections[rand.New(source).Intn(len(connections))] //nolint:gosec +} diff --git a/internal/provider/vyprvpn.go b/internal/provider/vyprvpn.go index 2d54598c..de190b5f 100644 --- a/internal/provider/vyprvpn.go +++ b/internal/provider/vyprvpn.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type vyprvpn struct { - servers []models.VyprvpnServer + servers []models.VyprvpnServer + randSource rand.Source } -func newVyprvpn(servers []models.VyprvpnServer) *vyprvpn { +func newVyprvpn(servers []models.VyprvpnServer, timeNow timeNowFunc) *vyprvpn { return &vyprvpn{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -36,27 +39,28 @@ func (v *vyprvpn) filterServers(region string) (servers []models.VyprvpnServer) return nil } -func (v *vyprvpn) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { +func (v *vyprvpn) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { servers := v.filterServers(selection.Region) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q", selection.Region) + return connection, fmt.Errorf("no server found for region %q", selection.Region) } var port uint16 switch { case selection.Protocol == constants.TCP: - return nil, fmt.Errorf("TCP protocol not supported by this VPN provider") + return connection, fmt.Errorf("TCP protocol not supported by this VPN provider") case selection.Protocol == constants.UDP: port = 443 default: - return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) + return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) } + var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { if selection.TargetIP != nil { if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) @@ -65,17 +69,13 @@ func (v *vyprvpn) GetOpenVPNConnections(selection models.ServerSelection) (conne } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, v.randSource), nil } -func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { if len(cipher) == 0 { cipher = aes256cbc } @@ -106,16 +106,14 @@ func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", connections[0].Protocol), + fmt.Sprintf("proto %s", connection.Protocol), + fmt.Sprintf("remote %s %d", connection.IP, connection.Port), fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("auth %s", auth), } if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP, connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----", diff --git a/internal/provider/windscribe.go b/internal/provider/windscribe.go index 95142582..3075e380 100644 --- a/internal/provider/windscribe.go +++ b/internal/provider/windscribe.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "math/rand" "net" "net/http" "strings" @@ -15,12 +16,14 @@ import ( ) type windscribe struct { - servers []models.WindscribeServer + servers []models.WindscribeServer + randSource rand.Source } -func newWindscribe(servers []models.WindscribeServer) *windscribe { +func newWindscribe(servers []models.WindscribeServer, timeNow timeNowFunc) *windscribe { return &windscribe{ - servers: servers, + servers: servers, + randSource: rand.NewSource(timeNow().UnixNano()), } } @@ -36,10 +39,10 @@ func (w *windscribe) filterServers(region string) (servers []models.WindscribeSe return nil } -func (w *windscribe) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { +func (w *windscribe) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) { servers := w.filterServers(selection.Region) if len(servers) == 0 { - return nil, fmt.Errorf("no server found for region %q", selection.Region) + return connection, fmt.Errorf("no server found for region %q", selection.Region) } var port uint16 @@ -51,14 +54,15 @@ func (w *windscribe) GetOpenVPNConnections(selection models.ServerSelection) (co case selection.Protocol == constants.UDP: port = 443 default: - return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) + return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) } + var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { if selection.TargetIP != nil { if selection.TargetIP.Equal(IP) { - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + return models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}, nil } } else { connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) @@ -67,17 +71,13 @@ func (w *windscribe) GetOpenVPNConnections(selection models.ServerSelection) (co } if selection.TargetIP != nil { - return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + return connection, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) } - if len(connections) > 64 { - connections = connections[:64] - } - - return connections, nil + return pickRandomConnection(connections, w.randSource), nil } -func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { if len(cipher) == 0 { cipher = aes256cbc } @@ -107,7 +107,8 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity // Modified variables fmt.Sprintf("verb %d", verbosity), fmt.Sprintf("auth-user-pass %s", constants.OpenVPNAuthConf), - fmt.Sprintf("proto %s", connections[0].Protocol), + fmt.Sprintf("proto %s", connection.Protocol), + fmt.Sprintf("remote %s %d", connection.IP, connection.Port), fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("auth %s", auth), } @@ -117,9 +118,6 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity if !root { lines = append(lines, "user nonrootuser") } - for _, connection := range connections { - lines = append(lines, fmt.Sprintf("remote %s %d", connection.IP, connection.Port)) - } lines = append(lines, []string{ "", "-----BEGIN CERTIFICATE-----",