chore(all): provider to servers map in allServers

- Simplify formatting CLI
- Simplify updater code
- Simplify filter choices for config validation
- Simplify all servers deep copying
- Custom JSON marshaling methods for `AllServers`
- Simplify provider constructor switch
- Simplify storage merging
- Simplify storage reading and extraction
- Simplify updating code
This commit is contained in:
Quentin McGaw
2022-05-27 00:59:47 +00:00
parent 5ffe8555ba
commit bd0868d764
22 changed files with 854 additions and 1295 deletions

View File

@@ -4,108 +4,17 @@ import (
"net"
)
func (a AllServers) GetCopy() (servers AllServers) {
servers = a // copy versions and timestamps
servers.Cyberghost.Servers = a.GetCyberghost()
servers.Expressvpn.Servers = a.GetExpressvpn()
servers.Fastestvpn.Servers = a.GetFastestvpn()
servers.HideMyAss.Servers = a.GetHideMyAss()
servers.Ipvanish.Servers = a.GetIpvanish()
servers.Ivpn.Servers = a.GetIvpn()
servers.Mullvad.Servers = a.GetMullvad()
servers.Nordvpn.Servers = a.GetNordvpn()
servers.Perfectprivacy.Servers = a.GetPerfectprivacy()
servers.Privado.Servers = a.GetPrivado()
servers.Pia.Servers = a.GetPia()
servers.Privatevpn.Servers = a.GetPrivatevpn()
servers.Protonvpn.Servers = a.GetProtonvpn()
servers.Purevpn.Servers = a.GetPurevpn()
servers.Surfshark.Servers = a.GetSurfshark()
servers.Torguard.Servers = a.GetTorguard()
servers.VPNUnlimited.Servers = a.GetVPNUnlimited()
servers.Vyprvpn.Servers = a.GetVyprvpn()
servers.Windscribe.Servers = a.GetWindscribe()
return servers
}
func (a *AllServers) GetCyberghost() (servers []Server) {
return copyServers(a.Cyberghost.Servers)
}
func (a *AllServers) GetExpressvpn() (servers []Server) {
return copyServers(a.Expressvpn.Servers)
}
func (a *AllServers) GetFastestvpn() (servers []Server) {
return copyServers(a.Fastestvpn.Servers)
}
func (a *AllServers) GetHideMyAss() (servers []Server) {
return copyServers(a.HideMyAss.Servers)
}
func (a *AllServers) GetIpvanish() (servers []Server) {
return copyServers(a.Ipvanish.Servers)
}
func (a *AllServers) GetIvpn() (servers []Server) {
return copyServers(a.Ivpn.Servers)
}
func (a *AllServers) GetMullvad() (servers []Server) {
return copyServers(a.Mullvad.Servers)
}
func (a *AllServers) GetNordvpn() (servers []Server) {
return copyServers(a.Nordvpn.Servers)
}
func (a *AllServers) GetPerfectprivacy() (servers []Server) {
return copyServers(a.Perfectprivacy.Servers)
}
func (a *AllServers) GetPia() (servers []Server) {
return copyServers(a.Pia.Servers)
}
func (a *AllServers) GetPrivado() (servers []Server) {
return copyServers(a.Privado.Servers)
}
func (a *AllServers) GetPrivatevpn() (servers []Server) {
return copyServers(a.Privatevpn.Servers)
}
func (a *AllServers) GetProtonvpn() (servers []Server) {
return copyServers(a.Protonvpn.Servers)
}
func (a *AllServers) GetPurevpn() (servers []Server) {
return copyServers(a.Purevpn.Servers)
}
func (a *AllServers) GetSurfshark() (servers []Server) {
return copyServers(a.Surfshark.Servers)
}
func (a *AllServers) GetTorguard() (servers []Server) {
return copyServers(a.Torguard.Servers)
}
func (a *AllServers) GetVPNUnlimited() (servers []Server) {
return copyServers(a.VPNUnlimited.Servers)
}
func (a *AllServers) GetVyprvpn() (servers []Server) {
return copyServers(a.Vyprvpn.Servers)
}
func (a *AllServers) GetWevpn() (servers []Server) {
return copyServers(a.Wevpn.Servers)
}
func (a *AllServers) GetWindscribe() (servers []Server) {
return copyServers(a.Windscribe.Servers)
func (a AllServers) GetCopy() (allServersCopy AllServers) {
allServersCopy.Version = a.Version
allServersCopy.ProviderToServers = make(map[string]Servers, len(a.ProviderToServers))
for provider, servers := range a.ProviderToServers {
allServersCopy.ProviderToServers[provider] = Servers{
Version: servers.Version,
Timestamp: servers.Timestamp,
Servers: copyServers(servers.Servers),
}
}
return allServersCopy
}
func copyServers(servers []Server) (serversCopy []Server) {

View File

@@ -4,108 +4,117 @@ import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_GetCopy(t *testing.T) {
allServers := AllServers{
Cyberghost: Servers{
Version: 2,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Expressvpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Fastestvpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
HideMyAss: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Ipvanish: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Ivpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Mullvad: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Nordvpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Perfectprivacy: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Privado: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Pia: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Privatevpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Protonvpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Purevpn: Servers{
Version: 1,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Surfshark: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Torguard: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
VPNUnlimited: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Vyprvpn: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
Windscribe: Servers{
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 2,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Expressvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Fastestvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.HideMyAss: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Ipvanish: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Ivpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Mullvad: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Nordvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Perfectprivacy: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Privado: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.PrivateInternetAccess: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Privatevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Protonvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Purevpn: {
Version: 1,
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Surfshark: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Torguard: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.VPNUnlimited: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Vyprvpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Wevpn: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
providers.Windscribe: {
Servers: []Server{{
IPs: []net.IP{{1, 2, 3, 4}},
}},
},
},
}
@@ -114,32 +123,6 @@ func Test_AllServers_GetCopy(t *testing.T) {
assert.Equal(t, allServers, servers)
}
func Test_AllServers_GetVyprvpn(t *testing.T) {
allServers := AllServers{
Vyprvpn: Servers{
Servers: []Server{
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
},
},
}
servers := allServers.GetVyprvpn()
expectedServers := []Server{
{Hostname: "a", IPs: []net.IP{{1, 1, 1, 1}}},
{Hostname: "b", IPs: []net.IP{{2, 2, 2, 2}}},
}
assert.Equal(t, expectedServers, servers)
allServers.Vyprvpn.Servers[0].IPs[0][0] = 9
assert.NotEqual(t, 9, servers[0].IPs[0][0])
allServers.Vyprvpn.Servers[0].IPs[0][0] = 1
servers[0].IPs[0][0] = 9
assert.NotEqual(t, 9, allServers.Vyprvpn.Servers[0].IPs[0][0])
}
func Test_copyIPs(t *testing.T) {
t.Parallel()

View File

@@ -1,54 +1,163 @@
package models
import (
"bytes"
"encoding/json"
"fmt"
"math"
"reflect"
"github.com/qdm12/gluetun/internal/constants/providers"
)
type AllServers struct {
Version uint16 `json:"version"` // used for migration of the top level scheme
Cyberghost Servers `json:"cyberghost"`
Expressvpn Servers `json:"expressvpn"`
Fastestvpn Servers `json:"fastestvpn"`
HideMyAss Servers `json:"hidemyass"`
Ipvanish Servers `json:"ipvanish"`
Ivpn Servers `json:"ivpn"`
Mullvad Servers `json:"mullvad"`
Perfectprivacy Servers `json:"perfect privacy"`
Nordvpn Servers `json:"nordvpn"`
Privado Servers `json:"privado"`
Pia Servers `json:"private internet access"`
Privatevpn Servers `json:"privatevpn"`
Protonvpn Servers `json:"protonvpn"`
Purevpn Servers `json:"purevpn"`
Surfshark Servers `json:"surfshark"`
Torguard Servers `json:"torguard"`
VPNUnlimited Servers `json:"vpn unlimited"`
Vyprvpn Servers `json:"vyprvpn"`
Wevpn Servers `json:"wevpn"`
Windscribe Servers `json:"windscribe"`
Version uint16 // used for migration of the top level scheme
ProviderToServers map[string]Servers
}
func (a *AllServers) Count() int {
return len(a.Cyberghost.Servers) +
len(a.Expressvpn.Servers) +
len(a.Fastestvpn.Servers) +
len(a.HideMyAss.Servers) +
len(a.Ipvanish.Servers) +
len(a.Ivpn.Servers) +
len(a.Mullvad.Servers) +
len(a.Nordvpn.Servers) +
len(a.Perfectprivacy.Servers) +
len(a.Privado.Servers) +
len(a.Pia.Servers) +
len(a.Privatevpn.Servers) +
len(a.Protonvpn.Servers) +
len(a.Purevpn.Servers) +
len(a.Surfshark.Servers) +
len(a.Torguard.Servers) +
len(a.VPNUnlimited.Servers) +
len(a.Vyprvpn.Servers) +
len(a.Wevpn.Servers) +
len(a.Windscribe.Servers)
func (a *AllServers) ServersSlice(provider string) []Server {
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
return copyServers(servers.Servers)
}
var _ json.Marshaler = (*AllServers)(nil)
// MarshalJSON marshals all servers to JSON.
// Note you need to use a pointer to all servers
// for it to work with native json methods such as
// json.Marshal.
func (a *AllServers) MarshalJSON() (data []byte, err error) {
buffer := bytes.NewBuffer(nil)
_, err = buffer.WriteString("{")
if err != nil {
return nil, fmt.Errorf("cannot write opening bracket: %w", err)
}
versionString := fmt.Sprintf(`"version":%d`, a.Version)
_, err = buffer.WriteString(versionString)
if err != nil {
return nil, fmt.Errorf("cannot write schema version string: %w", err)
}
for _, provider := range providers.All() {
servers, ok := a.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in all servers", provider))
}
providerKey := fmt.Sprintf(`,"%s":`, provider)
_, err = buffer.WriteString(providerKey)
if err != nil {
return nil, fmt.Errorf("cannot write provider key %s: %w",
providerKey, err)
}
serversJSON, err := json.Marshal(servers)
if err != nil {
return nil, fmt.Errorf("failed encoding servers for provider %s: %w",
provider, err)
}
_, err = buffer.Write(serversJSON)
if err != nil {
return nil, fmt.Errorf("cannot write JSON servers data for provider %s: %w",
provider, err)
}
}
_, err = buffer.WriteString("}")
if err != nil {
return nil, fmt.Errorf("cannot write closing bracket: %w", err)
}
return buffer.Bytes(), nil
}
var _ json.Unmarshaler = (*AllServers)(nil)
func (a *AllServers) UnmarshalJSON(data []byte) (err error) {
keyValues := make(map[string]interface{})
err = json.Unmarshal(data, &keyValues)
if err != nil {
return err
}
versionUnmarshaled := keyValues["version"]
if versionUnmarshaled != nil { // defaults to 0
version, ok := versionUnmarshaled.(float64)
if !ok {
return &json.UnmarshalTypeError{
Value: fmt.Sprintf("number %v", versionUnmarshaled),
Type: reflect.TypeOf(uint16(0)),
Struct: "models.AllServers",
Field: "Version",
}
}
if math.Round(version) != version ||
version < 0 || version > float64(^uint16(0)) {
return &json.UnmarshalTypeError{
Value: fmt.Sprintf("number %v", version),
Type: reflect.TypeOf(uint16(0)),
Struct: "models.AllServers",
Field: "Version",
}
}
a.Version = uint16(version)
delete(keyValues, "version")
}
if len(keyValues) == 0 {
return nil
}
a.ProviderToServers = make(map[string]Servers, len(keyValues))
allProviders := providers.All()
allProvidersSet := make(map[string]struct{}, len(allProviders))
for _, provider := range allProviders {
allProvidersSet[provider] = struct{}{}
}
for key, value := range keyValues {
if _, ok := allProvidersSet[key]; !ok {
// not a provider known by Gluetun
// or a non-servers field.
continue
}
jsonValue, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("cannot marshal %s servers: %w",
key, err)
}
var servers Servers
err = json.Unmarshal(jsonValue, &servers)
if err != nil {
return fmt.Errorf("cannot unmarshal %s servers: %w",
key, err)
}
a.ProviderToServers[key] = servers
}
return nil
}
func (a *AllServers) Count() (count int) {
for _, servers := range a.ProviderToServers {
count += len(servers.Servers)
}
return count
}
type Servers struct {
Version uint16 `json:"version"`
Timestamp int64 `json:"timestamp"`
Servers []Server `json:"servers"`
Servers []Server `json:"servers,omitempty"`
}

View File

@@ -0,0 +1,189 @@
package models
import (
"bytes"
"encoding/json"
"testing"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AllServers_MarshalJSON(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
allServers *AllServers
dataString string
errWrapped error
errMessage string
}{
"empty": {
allServers: &AllServers{
ProviderToServers: map[string]Servers{},
},
dataString: `{"version":0,` +
`"cyberghost":{"version":0,"timestamp":0},` +
`"expressvpn":{"version":0,"timestamp":0},` +
`"fastestvpn":{"version":0,"timestamp":0},` +
`"hidemyass":{"version":0,"timestamp":0},` +
`"ipvanish":{"version":0,"timestamp":0},` +
`"ivpn":{"version":0,"timestamp":0},` +
`"mullvad":{"version":0,"timestamp":0},` +
`"nordvpn":{"version":0,"timestamp":0},` +
`"perfect privacy":{"version":0,"timestamp":0},` +
`"privado":{"version":0,"timestamp":0},` +
`"private internet access":{"version":0,"timestamp":0},` +
`"privatevpn":{"version":0,"timestamp":0},` +
`"protonvpn":{"version":0,"timestamp":0},` +
`"purevpn":{"version":0,"timestamp":0},` +
`"surfshark":{"version":0,"timestamp":0},` +
`"torguard":{"version":0,"timestamp":0},` +
`"vpn unlimited":{"version":0,"timestamp":0},` +
`"vyprvpn":{"version":0,"timestamp":0},` +
`"wevpn":{"version":0,"timestamp":0},` +
`"windscribe":{"version":0,"timestamp":0}}`,
},
"two known providers": {
allServers: &AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 1,
Timestamp: 1000,
Servers: []Server{
{Country: "A"},
{Country: "B"},
},
},
providers.Privado: {
Version: 2,
Timestamp: 2000,
Servers: []Server{
{City: "C"},
{City: "D"},
},
},
},
},
dataString: `{"version":1,` +
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
`"expressvpn":{"version":0,"timestamp":0},` +
`"fastestvpn":{"version":0,"timestamp":0},` +
`"hidemyass":{"version":0,"timestamp":0},` +
`"ipvanish":{"version":0,"timestamp":0},` +
`"ivpn":{"version":0,"timestamp":0},` +
`"mullvad":{"version":0,"timestamp":0},` +
`"nordvpn":{"version":0,"timestamp":0},` +
`"perfect privacy":{"version":0,"timestamp":0},` +
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]},` +
`"private internet access":{"version":0,"timestamp":0},` +
`"privatevpn":{"version":0,"timestamp":0},` +
`"protonvpn":{"version":0,"timestamp":0},` +
`"purevpn":{"version":0,"timestamp":0},` +
`"surfshark":{"version":0,"timestamp":0},` +
`"torguard":{"version":0,"timestamp":0},` +
`"vpn unlimited":{"version":0,"timestamp":0},` +
`"vyprvpn":{"version":0,"timestamp":0},` +
`"wevpn":{"version":0,"timestamp":0},` +
`"windscribe":{"version":0,"timestamp":0}}`,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Populate all providers in all servers
for _, provider := range providers.All() {
_, has := testCase.allServers.ProviderToServers[provider]
if !has {
testCase.allServers.ProviderToServers[provider] = Servers{}
}
}
data, err := testCase.allServers.MarshalJSON()
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.dataString, string(data))
data, err = json.Marshal(testCase.allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.dataString, string(data))
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
// encoder.SetIndent("", " ")
err = encoder.Encode(testCase.allServers)
require.NoError(t, err)
assert.Equal(t, testCase.dataString+"\n", buffer.String())
})
}
}
func Test_AllServers_UnmarshalJSON(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
dataString string
allServers AllServers
errWrapped error
errMessage string
}{
"empty": {
dataString: "{}",
allServers: AllServers{},
},
"two known providers": {
dataString: `{"version":1,` +
`"cyberghost":{"version":1,"timestamp":1000,"servers":[{"country":"A"},{"country":"B"}]},` +
`"privado":{"version":2,"timestamp":2000,"servers":[{"city":"C"},{"city":"D"}]}}`,
allServers: AllServers{
Version: 1,
ProviderToServers: map[string]Servers{
providers.Cyberghost: {
Version: 1,
Timestamp: 1000,
Servers: []Server{
{Country: "A"},
{Country: "B"},
},
},
providers.Privado: {
Version: 2,
Timestamp: 2000,
Servers: []Server{
{City: "C"},
{City: "D"},
},
},
},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
data := []byte(testCase.dataString)
var allServers AllServers
err := json.Unmarshal(data, &allServers)
assert.ErrorIs(t, err, testCase.errWrapped)
if err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.allServers, allServers)
})
}
}