chore(storage): only pass hardcoded versions to read file
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
|||||||
// readFromFile reads the servers from server.json.
|
// readFromFile reads the servers from server.json.
|
||||||
// It only reads servers that have the same version as the hardcoded servers version
|
// It only reads servers that have the same version as the hardcoded servers version
|
||||||
// to avoid JSON unmarshaling errors.
|
// to avoid JSON unmarshaling errors.
|
||||||
func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
|
func (s *Storage) readFromFile(filepath string, hardcodedVersions map[string]uint16) (
|
||||||
servers models.AllServers, err error) {
|
servers models.AllServers, err error) {
|
||||||
file, err := os.Open(filepath)
|
file, err := os.Open(filepath)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
@@ -34,10 +34,10 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
|
|||||||
return servers, err
|
return servers, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.extractServersFromBytes(b, hardcoded)
|
return s.extractServersFromBytes(b, hardcodedVersions)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) (
|
func (s *Storage) extractServersFromBytes(b []byte, hardcodedVersions map[string]uint16) (
|
||||||
servers models.AllServers, err error) {
|
servers models.AllServers, err error) {
|
||||||
rawMessages := make(map[string]json.RawMessage)
|
rawMessages := make(map[string]json.RawMessage)
|
||||||
if err := json.Unmarshal(b, &rawMessages); err != nil {
|
if err := json.Unmarshal(b, &rawMessages); err != nil {
|
||||||
@@ -50,7 +50,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
|
|||||||
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
|
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
|
||||||
titleCaser := cases.Title(language.English)
|
titleCaser := cases.Title(language.English)
|
||||||
for _, provider := range allProviders {
|
for _, provider := range allProviders {
|
||||||
hardcoded, ok := hardcoded.ProviderToServers[provider]
|
hardcodedVersion, ok := hardcodedVersions[provider]
|
||||||
if !ok {
|
if !ok {
|
||||||
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
|
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
|
||||||
}
|
}
|
||||||
@@ -64,7 +64,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage, titleCaser)
|
mergedServers, versionsMatch, err := s.readServers(provider, hardcodedVersion, rawMessage, titleCaser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return models.AllServers{}, err
|
return models.AllServers{}, err
|
||||||
} else if !versionsMatch {
|
} else if !versionsMatch {
|
||||||
@@ -82,7 +82,7 @@ var (
|
|||||||
errDecodeProvider = errors.New("cannot decode servers for provider")
|
errDecodeProvider = errors.New("cannot decode servers for provider")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Storage) readServers(provider string, hardcoded models.Servers,
|
func (s *Storage) readServers(provider string, hardcodedVersion uint16,
|
||||||
rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers,
|
rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers,
|
||||||
versionsMatch bool, err error) {
|
versionsMatch bool, err error) {
|
||||||
provider = titleCaser.String(provider)
|
provider = titleCaser.String(provider)
|
||||||
@@ -93,9 +93,9 @@ func (s *Storage) readServers(provider string, hardcoded models.Servers,
|
|||||||
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
|
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
versionsMatch = hardcoded.Version == persistedServers.Version
|
versionsMatch = hardcodedVersion == persistedServers.Version
|
||||||
if !versionsMatch {
|
if !versionsMatch {
|
||||||
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version)
|
s.logVersionDiff(provider, hardcodedVersion, persistedServers.Version)
|
||||||
return servers, versionsMatch, nil
|
return servers, versionsMatch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,21 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
|
func populateProviderToVersion(allProviderVersion uint16,
|
||||||
|
providerToVersion map[string]uint16) map[string]uint16 {
|
||||||
|
allProviders := providers.All()
|
||||||
|
for _, provider := range allProviders {
|
||||||
|
_, has := providerToVersion[provider]
|
||||||
|
if has {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
providerToVersion[provider] = allProviderVersion
|
||||||
|
}
|
||||||
|
return providerToVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
func populateAllServersVersion(allProviderVersion uint16,
|
||||||
servers models.AllServers) models.AllServers {
|
servers models.AllServers) models.AllServers {
|
||||||
allProviders := providers.All()
|
allProviders := providers.All()
|
||||||
if servers.ProviderToServers == nil {
|
if servers.ProviderToServers == nil {
|
||||||
@@ -23,8 +37,7 @@ func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
servers.ProviderToServers[provider] = models.Servers{
|
servers.ProviderToServers[provider] = models.Servers{
|
||||||
Version: allProviderVersion,
|
Version: allProviderVersion,
|
||||||
Timestamp: allProviderTimestamp,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return servers
|
return servers
|
||||||
@@ -34,54 +47,54 @@ func Test_extractServersFromBytes(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
b []byte
|
b []byte
|
||||||
hardcoded models.AllServers
|
hardcodedVersions map[string]uint16
|
||||||
logged []string
|
logged []string
|
||||||
persisted models.AllServers
|
persisted models.AllServers
|
||||||
errMessage string
|
errMessage string
|
||||||
}{
|
}{
|
||||||
"bad JSON": {
|
"bad JSON": {
|
||||||
b: []byte("garbage"),
|
b: []byte("garbage"),
|
||||||
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
|
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
|
||||||
},
|
},
|
||||||
"bad provider JSON": {
|
"bad provider JSON": {
|
||||||
b: []byte(`{"cyberghost": "garbage"}`),
|
b: []byte(`{"cyberghost": "garbage"}`),
|
||||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
|
||||||
errMessage: "cannot decode servers for provider: Cyberghost: " +
|
errMessage: "cannot decode servers for provider: Cyberghost: " +
|
||||||
"json: cannot unmarshal string into Go value of type models.Servers",
|
"json: cannot unmarshal string into Go value of type models.Servers",
|
||||||
},
|
},
|
||||||
"absent provider keys": {
|
"absent provider keys": {
|
||||||
b: []byte(`{}`),
|
b: []byte(`{}`),
|
||||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
|
||||||
persisted: models.AllServers{
|
persisted: models.AllServers{
|
||||||
ProviderToServers: map[string]models.Servers{},
|
ProviderToServers: map[string]models.Servers{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"same versions": {
|
"same versions": {
|
||||||
b: []byte(`{
|
b: []byte(`{
|
||||||
"cyberghost": {"version": 1, "timestamp": 1},
|
"cyberghost": {"version": 1, "timestamp": 0},
|
||||||
"expressvpn": {"version": 1, "timestamp": 1},
|
"expressvpn": {"version": 1, "timestamp": 0},
|
||||||
"fastestvpn": {"version": 1, "timestamp": 1},
|
"fastestvpn": {"version": 1, "timestamp": 0},
|
||||||
"hidemyass": {"version": 1, "timestamp": 1},
|
"hidemyass": {"version": 1, "timestamp": 0},
|
||||||
"ipvanish": {"version": 1, "timestamp": 1},
|
"ipvanish": {"version": 1, "timestamp": 0},
|
||||||
"ivpn": {"version": 1, "timestamp": 1},
|
"ivpn": {"version": 1, "timestamp": 0},
|
||||||
"mullvad": {"version": 1, "timestamp": 1},
|
"mullvad": {"version": 1, "timestamp": 0},
|
||||||
"nordvpn": {"version": 1, "timestamp": 1},
|
"nordvpn": {"version": 1, "timestamp": 0},
|
||||||
"perfect privacy": {"version": 1, "timestamp": 1},
|
"perfect privacy": {"version": 1, "timestamp": 0},
|
||||||
"privado": {"version": 1, "timestamp": 1},
|
"privado": {"version": 1, "timestamp": 0},
|
||||||
"private internet access": {"version": 1, "timestamp": 1},
|
"private internet access": {"version": 1, "timestamp": 0},
|
||||||
"privatevpn": {"version": 1, "timestamp": 1},
|
"privatevpn": {"version": 1, "timestamp": 0},
|
||||||
"protonvpn": {"version": 1, "timestamp": 1},
|
"protonvpn": {"version": 1, "timestamp": 0},
|
||||||
"purevpn": {"version": 1, "timestamp": 1},
|
"purevpn": {"version": 1, "timestamp": 0},
|
||||||
"surfshark": {"version": 1, "timestamp": 1},
|
"surfshark": {"version": 1, "timestamp": 0},
|
||||||
"torguard": {"version": 1, "timestamp": 1},
|
"torguard": {"version": 1, "timestamp": 0},
|
||||||
"vpn unlimited": {"version": 1, "timestamp": 1},
|
"vpn unlimited": {"version": 1, "timestamp": 0},
|
||||||
"vyprvpn": {"version": 1, "timestamp": 1},
|
"vyprvpn": {"version": 1, "timestamp": 0},
|
||||||
"wevpn": {"version": 1, "timestamp": 1},
|
"wevpn": {"version": 1, "timestamp": 0},
|
||||||
"windscribe": {"version": 1, "timestamp": 1}
|
"windscribe": {"version": 1, "timestamp": 0}
|
||||||
}`),
|
}`),
|
||||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
|
||||||
persisted: populateProviders(1, 1, models.AllServers{}),
|
persisted: populateAllServersVersion(1, models.AllServers{}),
|
||||||
},
|
},
|
||||||
"different versions": {
|
"different versions": {
|
||||||
b: []byte(`{
|
b: []byte(`{
|
||||||
@@ -106,7 +119,7 @@ func Test_extractServersFromBytes(t *testing.T) {
|
|||||||
"wevpn": {"version": 1, "timestamp": 1},
|
"wevpn": {"version": 1, "timestamp": 1},
|
||||||
"windscribe": {"version": 1, "timestamp": 1}
|
"windscribe": {"version": 1, "timestamp": 1}
|
||||||
}`),
|
}`),
|
||||||
hardcoded: populateProviders(2, 0, models.AllServers{}),
|
hardcodedVersions: populateProviderToVersion(2, map[string]uint16{}),
|
||||||
logged: []string{
|
logged: []string{
|
||||||
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||||
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||||
@@ -155,7 +168,7 @@ func Test_extractServersFromBytes(t *testing.T) {
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded)
|
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcodedVersions)
|
||||||
|
|
||||||
if testCase.errMessage != "" {
|
if testCase.errMessage != "" {
|
||||||
assert.EqualError(t, err, testCase.errMessage)
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
@@ -176,15 +189,13 @@ func Test_extractServersFromBytes(t *testing.T) {
|
|||||||
require.GreaterOrEqual(t, len(allProviders), 2)
|
require.GreaterOrEqual(t, len(allProviders), 2)
|
||||||
|
|
||||||
b := []byte(`{}`)
|
b := []byte(`{}`)
|
||||||
hardcoded := models.AllServers{
|
hardcodedVersions := map[string]uint16{
|
||||||
ProviderToServers: map[string]models.Servers{
|
allProviders[0]: 1,
|
||||||
allProviders[0]: {},
|
// Missing provider allProviders[1]
|
||||||
// Missing provider allProviders[1]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
|
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
|
||||||
assert.PanicsWithValue(t, expectedPanicValue, func() {
|
assert.PanicsWithValue(t, expectedPanicValue, func() {
|
||||||
_, _ = s.extractServersFromBytes(b, hardcoded)
|
_, _ = s.extractServersFromBytes(b, hardcodedVersions)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ import (
|
|||||||
//go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer
|
//go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer
|
||||||
|
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
mergedServers models.AllServers
|
mergedServers models.AllServers
|
||||||
|
// this is stored in memory to avoid re-parsing
|
||||||
|
// the embedded JSON file on every call to the
|
||||||
|
// SyncServers method.
|
||||||
hardcodedServers models.AllServers
|
hardcodedServers models.AllServers
|
||||||
logger Infoer
|
logger Infoer
|
||||||
filepath string
|
filepath string
|
||||||
@@ -22,11 +25,12 @@ type Infoer interface {
|
|||||||
// embedded servers file and the file on disk.
|
// embedded servers file and the file on disk.
|
||||||
// Passing an empty filepath disables writing servers to a file.
|
// Passing an empty filepath disables writing servers to a file.
|
||||||
func New(logger Infoer, filepath string) (storage *Storage, err error) {
|
func New(logger Infoer, filepath string) (storage *Storage, err error) {
|
||||||
// error returned covered by unit test
|
// A unit test prevents any error from being returned
|
||||||
harcodedServers, _ := parseHardcodedServers()
|
// and ensures all providers are part of the servers returned.
|
||||||
|
hardcodedServers, _ := parseHardcodedServers()
|
||||||
|
|
||||||
storage = &Storage{
|
storage = &Storage{
|
||||||
hardcodedServers: harcodedServers,
|
hardcodedServers: hardcodedServers,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
filepath: filepath,
|
filepath: filepath,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,14 @@ func countServers(allServers models.AllServers) (count int) {
|
|||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncServers merges the hardcoded servers with the ones from the file.
|
||||||
func (s *Storage) SyncServers() (err error) {
|
func (s *Storage) SyncServers() (err error) {
|
||||||
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers)
|
hardcodedVersions := make(map[string]uint16, len(s.hardcodedServers.ProviderToServers))
|
||||||
|
for provider, servers := range s.hardcodedServers.ProviderToServers {
|
||||||
|
hardcodedVersions[provider] = servers.Version
|
||||||
|
}
|
||||||
|
|
||||||
|
serversOnFile, err := s.readFromFile(s.filepath, hardcodedVersions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot read servers from file: %w", err)
|
return fmt.Errorf("cannot read servers from file: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user