diff --git a/internal/constants/mullvad.go b/internal/constants/mullvad.go index 9dc0338f..a995996c 100644 --- a/internal/constants/mullvad.go +++ b/internal/constants/mullvad.go @@ -13,7 +13,7 @@ const ( func MullvadCountryChoices() (choices []string) { uniqueChoices := map[string]struct{}{} - for _, server := range mullvadServers() { + for _, server := range MullvadServers() { uniqueChoices[server.Country] = struct{}{} } for choice := range uniqueChoices { @@ -27,7 +27,7 @@ func MullvadCountryChoices() (choices []string) { func MullvadCityChoices() (choices []string) { uniqueChoices := map[string]struct{}{} - for _, server := range mullvadServers() { + for _, server := range MullvadServers() { uniqueChoices[server.City] = struct{}{} } for choice := range uniqueChoices { @@ -41,7 +41,7 @@ func MullvadCityChoices() (choices []string) { func MullvadISPChoices() (choices []string) { uniqueChoices := map[string]struct{}{} - for _, server := range mullvadServers() { + for _, server := range MullvadServers() { uniqueChoices[server.ISP] = struct{}{} } for choice := range uniqueChoices { @@ -53,25 +53,7 @@ func MullvadISPChoices() (choices []string) { return choices } -func MullvadServerFilter(country, city, isp string) (servers []models.MullvadServer) { - for _, server := range mullvadServers() { - if len(country) == 0 { - server.Country = "" - } - if len(city) == 0 { - server.City = "" - } - if len(isp) == 0 { - server.ISP = "" - } - if server.Country == country && server.City == city && server.ISP == isp { - servers = append(servers, server) - } - } - return servers -} - -func mullvadServers() []models.MullvadServer { +func MullvadServers() []models.MullvadServer { return []models.MullvadServer{ { Country: "united arab emirates", diff --git a/internal/provider/cyberghost.go b/internal/provider/cyberghost.go index 18bf2566..da26eec8 100644 --- a/internal/provider/cyberghost.go +++ b/internal/provider/cyberghost.go @@ -2,7 +2,6 @@ package provider import ( "fmt" - "net" "strings" "github.com/qdm12/golibs/network" @@ -16,32 +15,48 @@ func newCyberghost() *cyberghost { return &cyberghost{} } -func (c *cyberghost) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - var IPs []net.IP - for _, server := range constants.CyberghostServers() { - if strings.EqualFold(server.Region, selection.Region) && strings.EqualFold(server.Group, selection.Group) { - IPs = server.IPs +func (c *cyberghost) filterServers(region, group string) (servers []models.CyberghostServer) { + allServers := constants.CyberghostServers() + for i, server := range allServers { + if len(region) == 0 { + server.Region = "" + } + if len(group) == 0 { + server.Group = "" + } + if strings.EqualFold(server.Region, region) && strings.EqualFold(server.Group, group) { + servers = append(servers, allServers[i]) } } - if len(IPs) == 0 { - return nil, fmt.Errorf("no IP found for group %q and region %q", selection.Group, selection.Region) + return servers +} + +func (c *cyberghost) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + servers := c.filterServers(selection.Region, selection.Group) + if len(servers) == 0 { + return nil, fmt.Errorf("no server found for region %q and group %q", selection.Region, selection.Group) } - if selection.TargetIP != nil { - found := false - for i := range IPs { - if IPs[i].Equal(selection.TargetIP) { - found = true - break + + for _, server := range servers { + for _, IP := range server.IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return []models.OpenVPNConnection{{IP: IP, Port: 443, Protocol: selection.Protocol}}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: 443, Protocol: selection.Protocol}) } } - if !found { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) - } - IPs = []net.IP{selection.TargetIP} } - for _, IP := range IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: 443, Protocol: selection.Protocol}) + + if selection.TargetIP != nil { + return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) } + + if len(connections) > 64 { + connections = connections[:64] + } + return connections, nil } diff --git a/internal/provider/mullvad.go b/internal/provider/mullvad.go index a59bc9e6..7fbefc3a 100644 --- a/internal/provider/mullvad.go +++ b/internal/provider/mullvad.go @@ -14,11 +14,31 @@ func newMullvad() *mullvad { return &mullvad{} } +func (m *mullvad) filterServers(country, city, isp string) (servers []models.MullvadServer) { + allServers := constants.MullvadServers() + for i, server := range allServers { + if len(country) == 0 { + server.Country = "" + } + if len(city) == 0 { + server.City = "" + } + if len(isp) == 0 { + server.ISP = "" + } + if server.Country == country && server.City == city && server.ISP == isp { + servers = append(servers, allServers[i]) + } + } + return servers +} + func (m *mullvad) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - servers := constants.MullvadServerFilter(selection.Country, selection.City, selection.ISP) + servers := m.filterServers(selection.Country, selection.City, selection.ISP) if len(servers) == 0 { return nil, fmt.Errorf("no server found for country %q, city %q and ISP %q", selection.Country, selection.City, selection.ISP) } + for _, server := range servers { port := server.DefaultPort if selection.CustomPort > 0 { @@ -34,9 +54,15 @@ func (m *mullvad) GetOpenVPNConnections(selection models.ServerSelection) (conne } } } + if selection.TargetIP != nil { return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) } + + if len(connections) > 64 { + connections = connections[:64] + } + return connections, nil } diff --git a/internal/provider/nordvpn.go b/internal/provider/nordvpn.go index 8265f8bf..76f9acbf 100644 --- a/internal/provider/nordvpn.go +++ b/internal/provider/nordvpn.go @@ -2,8 +2,6 @@ package provider import ( "fmt" - "net" - "strings" "github.com/qdm12/golibs/network" "github.com/qdm12/private-internet-access-docker/internal/constants" @@ -16,60 +14,34 @@ func newNordvpn() *nordvpn { return &nordvpn{} } -func findServers(selection models.ServerSelection) (servers []models.NordvpnServer) { - for _, server := range constants.NordvpnServers() { - if strings.EqualFold(server.Region, selection.Region) { - if (selection.Protocol == constants.TCP && !server.TCP) || (selection.Protocol == constants.UDP && !server.UDP) { - continue - } - if selection.Number > 0 && server.Number == selection.Number { - return []models.NordvpnServer{server} - } - servers = append(servers, server) +func (n *nordvpn) filterServers(region string, protocol models.NetworkProtocol, number uint16) (servers []models.NordvpnServer) { + allServers := constants.NordvpnServers() + for i, server := range allServers { + if len(region) == 0 { + server.Region = "" + } + if number == 0 { + server.Number = 0 + } + + if protocol == constants.TCP && !server.TCP { + continue + } else if protocol == constants.UDP && !server.UDP { + continue + } + if server.Region == region && server.Number == number { + servers = append(servers, allServers[i]) } } return servers } -func extractIPsFromServers(servers []models.NordvpnServer) (ips []net.IP) { - ips = make([]net.IP, len(servers)) - for i := range servers { - ips[i] = servers[i].IP - } - return ips -} - -func targetIPInIps(targetIP net.IP, ips []net.IP) error { - for i := range ips { - if targetIP.Equal(ips[i]) { - return nil - } - } - ipsString := make([]string, len(ips)) - for i := range ips { - ipsString[i] = ips[i].String() - } - return fmt.Errorf("target IP address %s not found in IP addresses %s", targetIP, strings.Join(ipsString, ", ")) -} - func (n *nordvpn) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { //nolint:dupl - servers := findServers(selection) - ips := extractIPsFromServers(servers) - if len(ips) == 0 { - if selection.Number > 0 { - return nil, fmt.Errorf("no IP found for region %q, protocol %s and number %d", selection.Region, selection.Protocol, selection.Number) - } - return nil, fmt.Errorf("no IP found for region %q, protocol %s", selection.Region, selection.Protocol) - } - var IP net.IP - if selection.TargetIP != nil { - if err := targetIPInIps(selection.TargetIP, ips); err != nil { - return nil, err - } - IP = selection.TargetIP - } else { - IP = ips[0] + servers := n.filterServers(selection.Region, selection.Protocol, selection.Number) + if len(servers) == 0 { + return nil, fmt.Errorf("no server found for region %q, protocol %s and number %d", selection.Region, selection.Protocol, selection.Number) } + var port uint16 switch { case selection.Protocol == constants.UDP: @@ -79,7 +51,26 @@ func (n *nordvpn) GetOpenVPNConnections(selection models.ServerSelection) (conne default: return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) } - return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + + for _, server := range servers { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(server.IP) { + return []models.OpenVPNConnection{{IP: server.IP, Port: port, Protocol: selection.Protocol}}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: server.IP, Port: port, Protocol: selection.Protocol}) + } + } + + if selection.TargetIP != nil { + return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + if len(connections) > 64 { + connections = connections[:64] + } + + return connections, nil } func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { //nolint:dupl @@ -97,7 +88,6 @@ func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u "remote-cert-tls server", // Nordvpn specific - "resolv-retry infinite", "tun-mtu 1500", "tun-mtu-extra 32", "mssfix 1450", diff --git a/internal/provider/pia.go b/internal/provider/pia.go index 10b78780..e1bc0917 100644 --- a/internal/provider/pia.go +++ b/internal/provider/pia.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "net" "net/http" "strings" @@ -24,29 +23,24 @@ func newPrivateInternetAccess() *pia { } } -func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - var IPs []net.IP +func (p *pia) filterServers(region string) (servers []models.PIAServer) { + if len(region) == 0 { + return constants.PIAServers() + } for _, server := range constants.PIAServers() { - if strings.EqualFold(server.Region, selection.Region) { - IPs = server.IPs + if strings.EqualFold(server.Region, region) { + return []models.PIAServer{server} } } - if len(IPs) == 0 { - return nil, fmt.Errorf("no IP found for region %q", selection.Region) - } - if selection.TargetIP != nil { - found := false - for i := range IPs { - if IPs[i].Equal(selection.TargetIP) { - found = true - break - } - } - if !found { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) - } - IPs = []net.IP{selection.TargetIP} + return nil +} + +func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + servers := p.filterServers(selection.Region) + if len(servers) == 0 { + return nil, fmt.Errorf("no server found for region %q", selection.Region) } + var port uint16 switch selection.Protocol { case constants.TCP: @@ -67,9 +61,27 @@ func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connectio if port == 0 { return nil, fmt.Errorf("combination of protocol %q and encryption %q does not yield any port number", selection.Protocol, selection.EncryptionPreset) } - for _, IP := range IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + + for _, server := range servers { + for _, IP := range server.IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + } + } } + + if selection.TargetIP != nil { + return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + if len(connections) > 64 { + connections = connections[:64] + } + return connections, nil } diff --git a/internal/provider/surfshark.go b/internal/provider/surfshark.go index f0416257..e1e3f57d 100644 --- a/internal/provider/surfshark.go +++ b/internal/provider/surfshark.go @@ -2,7 +2,6 @@ package provider import ( "fmt" - "net" "strings" "github.com/qdm12/golibs/network" @@ -16,29 +15,24 @@ func newSurfshark() *surfshark { return &surfshark{} } -func (s *surfshark) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { //nolint:dupl - var IPs []net.IP +func (s *surfshark) filterServers(region string) (servers []models.SurfsharkServer) { + if len(region) == 0 { + return constants.SurfsharkServers() + } for _, server := range constants.SurfsharkServers() { - if strings.EqualFold(server.Region, selection.Region) { - IPs = server.IPs + if strings.EqualFold(server.Region, region) { + return []models.SurfsharkServer{server} } } - if len(IPs) == 0 { - return nil, fmt.Errorf("no IP found for region %q", selection.Region) - } - if selection.TargetIP != nil { - found := false - for i := range IPs { - if IPs[i].Equal(selection.TargetIP) { - found = true - break - } - } - if !found { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) - } - IPs = []net.IP{selection.TargetIP} + return nil +} + +func (s *surfshark) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { //nolint:dupl + servers := s.filterServers(selection.Region) + if len(servers) == 0 { + return nil, fmt.Errorf("no server found for region %q", selection.Region) } + var port uint16 switch { case selection.Protocol == constants.TCP: @@ -48,9 +42,27 @@ func (s *surfshark) GetOpenVPNConnections(selection models.ServerSelection) (con default: return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) } - for _, IP := range IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + + for _, server := range servers { + for _, IP := range server.IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + } + } } + + if selection.TargetIP != nil { + return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + if len(connections) > 64 { + connections = connections[:64] + } + return connections, nil } diff --git a/internal/provider/vyprvpn.go b/internal/provider/vyprvpn.go index 8abd40a2..3eb28677 100644 --- a/internal/provider/vyprvpn.go +++ b/internal/provider/vyprvpn.go @@ -2,7 +2,6 @@ package provider import ( "fmt" - "net" "strings" "github.com/qdm12/golibs/network" @@ -16,29 +15,24 @@ func newVyprvpn() *vyprvpn { return &vyprvpn{} } -func (s *vyprvpn) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - var IPs []net.IP +func (v *vyprvpn) filterServers(region string) (servers []models.VyprvpnServer) { + if len(region) == 0 { + return constants.VyprvpnServers() + } for _, server := range constants.VyprvpnServers() { - if strings.EqualFold(server.Region, selection.Region) { - IPs = server.IPs + if strings.EqualFold(server.Region, region) { + return []models.VyprvpnServer{server} } } - if len(IPs) == 0 { - return nil, fmt.Errorf("no IP found for region %q", selection.Region) - } - if selection.TargetIP != nil { - found := false - for i := range IPs { - if IPs[i].Equal(selection.TargetIP) { - found = true - break - } - } - if !found { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) - } - IPs = []net.IP{selection.TargetIP} + return nil +} + +func (v *vyprvpn) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + servers := v.filterServers(selection.Region) + if len(servers) == 0 { + return nil, fmt.Errorf("no server found for region %q", selection.Region) } + var port uint16 switch { case selection.Protocol == constants.TCP: @@ -48,13 +42,31 @@ func (s *vyprvpn) GetOpenVPNConnections(selection models.ServerSelection) (conne default: return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) } - for _, IP := range IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + + for _, server := range servers { + for _, IP := range server.IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + } + } } + + if selection.TargetIP != nil { + return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + if len(connections) > 64 { + connections = connections[:64] + } + return connections, nil } -func (s *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { +func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { if len(cipher) == 0 { cipher = aes256cbc } @@ -105,6 +117,6 @@ func (s *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u return lines } -func (s *vyprvpn) GetPortForward(client network.Client) (port uint16, err error) { +func (v *vyprvpn) GetPortForward(client network.Client) (port uint16, err error) { panic("port forwarding is not supported for vyprvpn") } diff --git a/internal/provider/windscribe.go b/internal/provider/windscribe.go index 814e20cb..da757352 100644 --- a/internal/provider/windscribe.go +++ b/internal/provider/windscribe.go @@ -2,7 +2,6 @@ package provider import ( "fmt" - "net" "strings" "github.com/qdm12/golibs/network" @@ -16,29 +15,24 @@ func newWindscribe() *windscribe { return &windscribe{} } -func (w *windscribe) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { - var IPs []net.IP +func (w *windscribe) filterServers(region string) (servers []models.WindscribeServer) { + if len(region) == 0 { + return constants.WindscribeServers() + } for _, server := range constants.WindscribeServers() { - if strings.EqualFold(server.Region, selection.Region) { - IPs = server.IPs + if strings.EqualFold(server.Region, region) { + return []models.WindscribeServer{server} } } - if len(IPs) == 0 { - return nil, fmt.Errorf("no IP found for region %q", selection.Region) - } - if selection.TargetIP != nil { - found := false - for i := range IPs { - if IPs[i].Equal(selection.TargetIP) { - found = true - break - } - } - if !found { - return nil, fmt.Errorf("target IP address %q not found in IP addresses", selection.TargetIP) - } - IPs = []net.IP{selection.TargetIP} + return nil +} + +func (w *windscribe) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { + servers := w.filterServers(selection.Region) + if len(servers) == 0 { + return nil, fmt.Errorf("no server found for region %q", selection.Region) } + var port uint16 switch { case selection.CustomPort > 0: @@ -50,9 +44,27 @@ func (w *windscribe) GetOpenVPNConnections(selection models.ServerSelection) (co default: return nil, fmt.Errorf("protocol %q is unknown", selection.Protocol) } - for _, IP := range IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + + for _, server := range servers { + for _, IP := range server.IPs { + if selection.TargetIP != nil { + if selection.TargetIP.Equal(IP) { + return []models.OpenVPNConnection{{IP: IP, Port: port, Protocol: selection.Protocol}}, nil + } + } else { + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + } + } } + + if selection.TargetIP != nil { + return nil, fmt.Errorf("target IP %s not found in IP addresses", selection.TargetIP) + } + + if len(connections) > 64 { + connections = connections[:64] + } + return connections, nil }