chore(storage): only pass hardcoded versions to read file

This commit is contained in:
Quentin McGaw
2022-05-28 22:36:16 +00:00
parent 22455ac76f
commit 90dd3b1b5c
4 changed files with 76 additions and 55 deletions

View File

@@ -16,7 +16,7 @@ import (
// readFromFile reads the servers from server.json. // readFromFile reads the servers from server.json.
// It only reads servers that have the same version as the hardcoded servers version // It only reads servers that have the same version as the hardcoded servers version
// to avoid JSON unmarshaling errors. // to avoid JSON unmarshaling errors.
func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) ( func (s *Storage) readFromFile(filepath string, hardcodedVersions map[string]uint16) (
servers models.AllServers, err error) { servers models.AllServers, err error) {
file, err := os.Open(filepath) file, err := os.Open(filepath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -34,10 +34,10 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
return servers, err return servers, err
} }
return s.extractServersFromBytes(b, hardcoded) return s.extractServersFromBytes(b, hardcodedVersions)
} }
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) ( func (s *Storage) extractServersFromBytes(b []byte, hardcodedVersions map[string]uint16) (
servers models.AllServers, err error) { servers models.AllServers, err error) {
rawMessages := make(map[string]json.RawMessage) rawMessages := make(map[string]json.RawMessage)
if err := json.Unmarshal(b, &rawMessages); err != nil { if err := json.Unmarshal(b, &rawMessages); err != nil {
@@ -50,7 +50,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)) servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
titleCaser := cases.Title(language.English) titleCaser := cases.Title(language.English)
for _, provider := range allProviders { for _, provider := range allProviders {
hardcoded, ok := hardcoded.ProviderToServers[provider] hardcodedVersion, ok := hardcodedVersions[provider]
if !ok { if !ok {
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider)) panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
} }
@@ -64,7 +64,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
continue continue
} }
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage, titleCaser) mergedServers, versionsMatch, err := s.readServers(provider, hardcodedVersion, rawMessage, titleCaser)
if err != nil { if err != nil {
return models.AllServers{}, err return models.AllServers{}, err
} else if !versionsMatch { } else if !versionsMatch {
@@ -82,7 +82,7 @@ var (
errDecodeProvider = errors.New("cannot decode servers for provider") errDecodeProvider = errors.New("cannot decode servers for provider")
) )
func (s *Storage) readServers(provider string, hardcoded models.Servers, func (s *Storage) readServers(provider string, hardcodedVersion uint16,
rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers, rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers,
versionsMatch bool, err error) { versionsMatch bool, err error) {
provider = titleCaser.String(provider) provider = titleCaser.String(provider)
@@ -93,9 +93,9 @@ func (s *Storage) readServers(provider string, hardcoded models.Servers,
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err) return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
} }
versionsMatch = hardcoded.Version == persistedServers.Version versionsMatch = hardcodedVersion == persistedServers.Version
if !versionsMatch { if !versionsMatch {
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version) s.logVersionDiff(provider, hardcodedVersion, persistedServers.Version)
return servers, versionsMatch, nil return servers, versionsMatch, nil
} }

View File

@@ -11,7 +11,21 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64, func populateProviderToVersion(allProviderVersion uint16,
providerToVersion map[string]uint16) map[string]uint16 {
allProviders := providers.All()
for _, provider := range allProviders {
_, has := providerToVersion[provider]
if has {
continue
}
providerToVersion[provider] = allProviderVersion
}
return providerToVersion
}
func populateAllServersVersion(allProviderVersion uint16,
servers models.AllServers) models.AllServers { servers models.AllServers) models.AllServers {
allProviders := providers.All() allProviders := providers.All()
if servers.ProviderToServers == nil { if servers.ProviderToServers == nil {
@@ -23,8 +37,7 @@ func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
continue continue
} }
servers.ProviderToServers[provider] = models.Servers{ servers.ProviderToServers[provider] = models.Servers{
Version: allProviderVersion, Version: allProviderVersion,
Timestamp: allProviderTimestamp,
} }
} }
return servers return servers
@@ -34,54 +47,54 @@ func Test_extractServersFromBytes(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
b []byte b []byte
hardcoded models.AllServers hardcodedVersions map[string]uint16
logged []string logged []string
persisted models.AllServers persisted models.AllServers
errMessage string errMessage string
}{ }{
"bad JSON": { "bad JSON": {
b: []byte("garbage"), b: []byte("garbage"),
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value", errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
}, },
"bad provider JSON": { "bad provider JSON": {
b: []byte(`{"cyberghost": "garbage"}`), b: []byte(`{"cyberghost": "garbage"}`),
hardcoded: populateProviders(1, 0, models.AllServers{}), hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
errMessage: "cannot decode servers for provider: Cyberghost: " + errMessage: "cannot decode servers for provider: Cyberghost: " +
"json: cannot unmarshal string into Go value of type models.Servers", "json: cannot unmarshal string into Go value of type models.Servers",
}, },
"absent provider keys": { "absent provider keys": {
b: []byte(`{}`), b: []byte(`{}`),
hardcoded: populateProviders(1, 0, models.AllServers{}), hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
persisted: models.AllServers{ persisted: models.AllServers{
ProviderToServers: map[string]models.Servers{}, ProviderToServers: map[string]models.Servers{},
}, },
}, },
"same versions": { "same versions": {
b: []byte(`{ b: []byte(`{
"cyberghost": {"version": 1, "timestamp": 1}, "cyberghost": {"version": 1, "timestamp": 0},
"expressvpn": {"version": 1, "timestamp": 1}, "expressvpn": {"version": 1, "timestamp": 0},
"fastestvpn": {"version": 1, "timestamp": 1}, "fastestvpn": {"version": 1, "timestamp": 0},
"hidemyass": {"version": 1, "timestamp": 1}, "hidemyass": {"version": 1, "timestamp": 0},
"ipvanish": {"version": 1, "timestamp": 1}, "ipvanish": {"version": 1, "timestamp": 0},
"ivpn": {"version": 1, "timestamp": 1}, "ivpn": {"version": 1, "timestamp": 0},
"mullvad": {"version": 1, "timestamp": 1}, "mullvad": {"version": 1, "timestamp": 0},
"nordvpn": {"version": 1, "timestamp": 1}, "nordvpn": {"version": 1, "timestamp": 0},
"perfect privacy": {"version": 1, "timestamp": 1}, "perfect privacy": {"version": 1, "timestamp": 0},
"privado": {"version": 1, "timestamp": 1}, "privado": {"version": 1, "timestamp": 0},
"private internet access": {"version": 1, "timestamp": 1}, "private internet access": {"version": 1, "timestamp": 0},
"privatevpn": {"version": 1, "timestamp": 1}, "privatevpn": {"version": 1, "timestamp": 0},
"protonvpn": {"version": 1, "timestamp": 1}, "protonvpn": {"version": 1, "timestamp": 0},
"purevpn": {"version": 1, "timestamp": 1}, "purevpn": {"version": 1, "timestamp": 0},
"surfshark": {"version": 1, "timestamp": 1}, "surfshark": {"version": 1, "timestamp": 0},
"torguard": {"version": 1, "timestamp": 1}, "torguard": {"version": 1, "timestamp": 0},
"vpn unlimited": {"version": 1, "timestamp": 1}, "vpn unlimited": {"version": 1, "timestamp": 0},
"vyprvpn": {"version": 1, "timestamp": 1}, "vyprvpn": {"version": 1, "timestamp": 0},
"wevpn": {"version": 1, "timestamp": 1}, "wevpn": {"version": 1, "timestamp": 0},
"windscribe": {"version": 1, "timestamp": 1} "windscribe": {"version": 1, "timestamp": 0}
}`), }`),
hardcoded: populateProviders(1, 0, models.AllServers{}), hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
persisted: populateProviders(1, 1, models.AllServers{}), persisted: populateAllServersVersion(1, models.AllServers{}),
}, },
"different versions": { "different versions": {
b: []byte(`{ b: []byte(`{
@@ -106,7 +119,7 @@ func Test_extractServersFromBytes(t *testing.T) {
"wevpn": {"version": 1, "timestamp": 1}, "wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1} "windscribe": {"version": 1, "timestamp": 1}
}`), }`),
hardcoded: populateProviders(2, 0, models.AllServers{}), hardcodedVersions: populateProviderToVersion(2, map[string]uint16{}),
logged: []string{ logged: []string{
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2", "Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", "Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
@@ -155,7 +168,7 @@ func Test_extractServersFromBytes(t *testing.T) {
logger: logger, logger: logger,
} }
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded) servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcodedVersions)
if testCase.errMessage != "" { if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage) assert.EqualError(t, err, testCase.errMessage)
@@ -176,15 +189,13 @@ func Test_extractServersFromBytes(t *testing.T) {
require.GreaterOrEqual(t, len(allProviders), 2) require.GreaterOrEqual(t, len(allProviders), 2)
b := []byte(`{}`) b := []byte(`{}`)
hardcoded := models.AllServers{ hardcodedVersions := map[string]uint16{
ProviderToServers: map[string]models.Servers{ allProviders[0]: 1,
allProviders[0]: {}, // Missing provider allProviders[1]
// Missing provider allProviders[1]
},
} }
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1]) expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
assert.PanicsWithValue(t, expectedPanicValue, func() { assert.PanicsWithValue(t, expectedPanicValue, func() {
_, _ = s.extractServersFromBytes(b, hardcoded) _, _ = s.extractServersFromBytes(b, hardcodedVersions)
}) })
}) })
} }

View File

@@ -8,7 +8,10 @@ import (
//go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer //go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer
type Storage struct { type Storage struct {
mergedServers models.AllServers mergedServers models.AllServers
// this is stored in memory to avoid re-parsing
// the embedded JSON file on every call to the
// SyncServers method.
hardcodedServers models.AllServers hardcodedServers models.AllServers
logger Infoer logger Infoer
filepath string filepath string
@@ -22,11 +25,12 @@ type Infoer interface {
// embedded servers file and the file on disk. // embedded servers file and the file on disk.
// Passing an empty filepath disables writing servers to a file. // Passing an empty filepath disables writing servers to a file.
func New(logger Infoer, filepath string) (storage *Storage, err error) { func New(logger Infoer, filepath string) (storage *Storage, err error) {
// error returned covered by unit test // A unit test prevents any error from being returned
harcodedServers, _ := parseHardcodedServers() // and ensures all providers are part of the servers returned.
hardcodedServers, _ := parseHardcodedServers()
storage = &Storage{ storage = &Storage{
hardcodedServers: harcodedServers, hardcodedServers: hardcodedServers,
logger: logger, logger: logger,
filepath: filepath, filepath: filepath,
} }

View File

@@ -14,8 +14,14 @@ func countServers(allServers models.AllServers) (count int) {
return count return count
} }
// SyncServers merges the hardcoded servers with the ones from the file.
func (s *Storage) SyncServers() (err error) { func (s *Storage) SyncServers() (err error) {
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers) hardcodedVersions := make(map[string]uint16, len(s.hardcodedServers.ProviderToServers))
for provider, servers := range s.hardcodedServers.ProviderToServers {
hardcodedVersions[provider] = servers.Version
}
serversOnFile, err := s.readFromFile(s.filepath, hardcodedVersions)
if err != nil { if err != nil {
return fmt.Errorf("cannot read servers from file: %w", err) return fmt.Errorf("cannot read servers from file: %w", err)
} }