chore(all): memory and thread safe storage

- settings: get filter choices from storage for settings validation
- updater: update servers to the storage
- storage: minimal deep copying and data duplication
- storage: add merged servers mutex for thread safety
- connection: filter servers in storage
- formatter: format servers to Markdown in storage
- PIA: get server by name from storage directly
- Updater: get servers count from storage directly
- Updater: equality check done in storage, fix #882
This commit is contained in:
Quentin McGaw
2022-06-05 14:58:46 +00:00
parent 1e6b4ed5eb
commit 36b504609b
84 changed files with 1267 additions and 877 deletions

View File

@@ -0,0 +1,27 @@
package storage
import (
"github.com/qdm12/gluetun/internal/configuration/settings/validation"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
func (s *Storage) GetFilterChoices(provider string) models.FilterChoices {
if provider == providers.Custom {
return models.FilterChoices{}
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
servers := serversObject.Servers
return models.FilterChoices{
Countries: validation.ExtractCountries(servers),
Regions: validation.ExtractRegions(servers),
Cities: validation.ExtractCities(servers),
ISPs: validation.ExtractISPs(servers),
Names: validation.ExtractServerNames(servers),
Hostnames: validation.ExtractHostnames(servers),
}
}

32
internal/storage/copy.go Normal file
View File

@@ -0,0 +1,32 @@
package storage
import (
"net"
"github.com/qdm12/gluetun/internal/models"
)
func copyServer(server models.Server) (serverCopy models.Server) {
serverCopy = server
serverCopy.IPs = copyIPs(server.IPs)
return serverCopy
}
func copyIPs(toCopy []net.IP) (copied []net.IP) {
if toCopy == nil {
return nil
}
copied = make([]net.IP, len(toCopy))
for i := range toCopy {
copied[i] = copyIP(toCopy[i])
}
return copied
}
func copyIP(toCopy net.IP) (copied net.IP) {
copied = make(net.IP, len(toCopy))
copy(copied, toCopy)
return copied
}

View File

@@ -0,0 +1,75 @@
package storage
import (
"net"
"testing"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_copyServer(t *testing.T) {
t.Parallel()
server := models.Server{
Country: "a",
IPs: []net.IP{{1, 2, 3, 4}},
}
serverCopy := copyServer(server)
assert.Equal(t, server, serverCopy)
// Check for mutation
serverCopy.IPs[0][0] = 9
assert.NotEqual(t, server, serverCopy)
}
func Test_copyIPs(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
toCopy []net.IP
copied []net.IP
}{
"nil": {},
"empty": {
toCopy: []net.IP{},
copied: []net.IP{},
},
"single IP": {
toCopy: []net.IP{{1, 1, 1, 1}},
copied: []net.IP{{1, 1, 1, 1}},
},
"two IPs": {
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
// Reserver leading 9 for copy modifications below
for _, ipToCopy := range testCase.toCopy {
require.NotEqual(t, 9, ipToCopy[0])
}
copied := copyIPs(testCase.toCopy)
assert.Equal(t, testCase.copied, copied)
if len(copied) > 0 {
original := testCase.toCopy[0][0]
testCase.toCopy[0][0] = 9
assert.NotEqual(t, 9, copied[0][0])
testCase.toCopy[0][0] = original
copied[0][0] = 9
assert.NotEqual(t, 9, testCase.toCopy[0][0])
}
})
}
}

143
internal/storage/filter.go Normal file
View File

@@ -0,0 +1,143 @@
package storage
import (
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
)
// FilterServers filter servers for the given provider and according
// to the given selection. The filtered servers are deep copied so they
// are safe for mutation by the caller.
func (s *Storage) FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error) {
if provider == providers.Custom {
return nil, nil
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
allServers := serversObject.Servers
if len(allServers) == 0 {
return nil, ErrNoServerFound
}
for _, server := range allServers {
if filterServer(server, selection) {
continue
}
server = copyServer(server)
servers = append(servers, server)
}
if len(servers) == 0 {
return nil, noServerFoundError(selection)
}
return servers, nil
}
func filterServer(server models.Server,
selection settings.ServerSelection) (filtered bool) {
// Note each condition is split to make sure
// we have full testing coverage.
if server.VPN != selection.VPN {
return true
}
if filterByProtocol(selection, server.TCP, server.UDP) {
return true
}
if *selection.MultiHopOnly && !server.MultiHop {
return true
}
if *selection.FreeOnly && !server.Free {
return true
}
if *selection.StreamOnly && !server.Stream {
return true
}
if *selection.OwnedOnly && !server.Owned {
return true
}
if filterByPossibilities(server.Country, selection.Countries) {
return true
}
if filterByPossibilities(server.Region, selection.Regions) {
return true
}
if filterByPossibilities(server.City, selection.Cities) {
return true
}
if filterByPossibilities(server.ISP, selection.ISPs) {
return true
}
if filterByPossibilitiesUint16(server.Number, selection.Numbers) {
return true
}
if filterByPossibilities(server.ServerName, selection.Names) {
return true
}
if filterByPossibilities(server.Hostname, selection.Hostnames) {
return true
}
// TODO filter port forward server for PIA
return false
}
func filterByPossibilities(value string, possibilities []string) (filtered bool) {
if len(possibilities) == 0 {
return false
}
for _, possibility := range possibilities {
if strings.EqualFold(value, possibility) {
return false
}
}
return true
}
// TODO merge with filterByPossibilities with generics in Go 1.18.
func filterByPossibilitiesUint16(value uint16, possibilities []uint16) (filtered bool) {
if len(possibilities) == 0 {
return false
}
for _, possibility := range possibilities {
if value == possibility {
return false
}
}
return true
}
func filterByProtocol(selection settings.ServerSelection,
serverTCP, serverUDP bool) (filtered bool) {
switch selection.VPN {
case vpn.Wireguard:
return !serverUDP
default: // OpenVPN
wantTCP := *selection.OpenVPN.TCP
wantUDP := !wantTCP
return (wantTCP && !serverTCP) || (wantUDP && !serverUDP)
}
}

View File

@@ -4,21 +4,20 @@ import (
"encoding/json"
"os"
"path/filepath"
"github.com/qdm12/gluetun/internal/models"
)
var _ Flusher = (*Storage)(nil)
// FlushToFile flushes the merged servers data to the file
// specified by path, as indented JSON.
func (s *Storage) FlushToFile(path string) error {
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
type Flusher interface {
FlushToFile(allServers *models.AllServers) error
return s.flushToFile(path)
}
func (s *Storage) FlushToFile(allServers *models.AllServers) error {
return flushToFile(s.filepath, allServers)
}
func flushToFile(path string, servers *models.AllServers) error {
// flushToFile flushes the merged servers data to the file
// specified by path, as indented JSON. It is not thread-safe.
func (s *Storage) flushToFile(path string) error {
dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil {
return err
@@ -28,11 +27,15 @@ func flushToFile(path string, servers *models.AllServers) error {
if err != nil {
return err
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
err = encoder.Encode(&s.mergedServers)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -0,0 +1,120 @@
package storage
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
)
func commaJoin(slice []string) string {
return strings.Join(slice, ", ")
}
var ErrNoServerFound = errors.New("no server found")
func noServerFoundError(selection settings.ServerSelection) (err error) {
var messageParts []string
messageParts = append(messageParts, "VPN "+selection.VPN)
protocol := constants.UDP
if *selection.OpenVPN.TCP {
protocol = constants.TCP
}
messageParts = append(messageParts, "protocol "+protocol)
switch len(selection.Countries) {
case 0:
case 1:
part := "country " + selection.Countries[0]
messageParts = append(messageParts, part)
default:
part := "countries " + commaJoin(selection.Countries)
messageParts = append(messageParts, part)
}
switch len(selection.Regions) {
case 0:
case 1:
part := "region " + selection.Regions[0]
messageParts = append(messageParts, part)
default:
part := "regions " + commaJoin(selection.Regions)
messageParts = append(messageParts, part)
}
switch len(selection.Cities) {
case 0:
case 1:
part := "city " + selection.Cities[0]
messageParts = append(messageParts, part)
default:
part := "cities " + commaJoin(selection.Cities)
messageParts = append(messageParts, part)
}
if *selection.OwnedOnly {
messageParts = append(messageParts, "owned servers only")
}
switch len(selection.ISPs) {
case 0:
case 1:
part := "ISP " + selection.ISPs[0]
messageParts = append(messageParts, part)
default:
part := "ISPs " + commaJoin(selection.ISPs)
messageParts = append(messageParts, part)
}
switch len(selection.Hostnames) {
case 0:
case 1:
part := "hostname " + selection.Hostnames[0]
messageParts = append(messageParts, part)
default:
part := "hostnames " + commaJoin(selection.Hostnames)
messageParts = append(messageParts, part)
}
switch len(selection.Names) {
case 0:
case 1:
part := "name " + selection.Names[0]
messageParts = append(messageParts, part)
default:
part := "names " + commaJoin(selection.Names)
messageParts = append(messageParts, part)
}
switch len(selection.Numbers) {
case 0:
case 1:
part := "server number " + strconv.Itoa(int(selection.Numbers[0]))
messageParts = append(messageParts, part)
default:
serverNumbers := make([]string, len(selection.Numbers))
for i := range selection.Numbers {
serverNumbers[i] = strconv.Itoa(int(selection.Numbers[i]))
}
part := "server numbers " + commaJoin(serverNumbers)
messageParts = append(messageParts, part)
}
if *selection.OpenVPN.PIAEncPreset != "" {
part := "encryption preset " + *selection.OpenVPN.PIAEncPreset
messageParts = append(messageParts, part)
}
if *selection.FreeOnly {
messageParts = append(messageParts, "free tier only")
}
message := "for " + strings.Join(messageParts, "; ")
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
}

View File

@@ -1,7 +1,118 @@
package storage
import "github.com/qdm12/gluetun/internal/models"
import (
"fmt"
"time"
func (s *Storage) GetServers() models.AllServers {
return s.mergedServers.GetCopy()
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/models"
)
// SetServers sets the given servers for the given provider
// in the storage in-memory map and saves all the servers
// to file.
// Note the servers given are not copied so the caller must
// NOT MUTATE them after calling this method.
func (s *Storage) SetServers(provider string, servers []models.Server) (err error) {
if provider == providers.Custom {
return
}
s.mergedMutex.Lock()
defer s.mergedMutex.Unlock()
serversObject := s.getMergedServersObject(provider)
serversObject.Timestamp = time.Now().Unix()
serversObject.Servers = servers
s.mergedServers.ProviderToServers[provider] = serversObject
err = s.flushToFile(s.filepath)
if err != nil {
return fmt.Errorf("cannot save servers to file: %w", err)
}
return nil
}
// GetServerByName returns the server for the given provider
// and server name. It returns `ok` as false if the server is
// not found. The returned server is also deep copied so it is
// safe for mutation and/or thread safe use.
func (s *Storage) GetServerByName(provider, name string) (
server models.Server, ok bool) {
if provider == providers.Custom {
return server, false
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
for _, server := range serversObject.Servers {
if server.ServerName == name {
return copyServer(server), true
}
}
return server, false
}
// GetServersCount returns the number of servers for the provider given.
func (s *Storage) GetServersCount(provider string) (count int) {
if provider == providers.Custom {
return 0
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
return len(serversObject.Servers)
}
// FormatToMarkdown Markdown formats the servers for the provider given
// and returns the resulting string.
func (s *Storage) FormatToMarkdown(provider string) (formatted string) {
if provider == providers.Custom {
return ""
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
formatted = serversObject.ToMarkdown(provider)
return formatted
}
// GetServersCount returns the number of servers for the provider given.
func (s *Storage) ServersAreEqual(provider string, servers []models.Server) (equal bool) {
if provider == providers.Custom {
return true
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
existingServers := serversObject.Servers
if len(existingServers) != len(servers) {
return false
}
for i := range existingServers {
if !existingServers[i].Equal(servers[i]) {
return false
}
}
return true
}
func (s *Storage) getMergedServersObject(provider string) (serversObject models.Servers) {
serversObject, ok := s.mergedServers.ProviderToServers[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in in-memory servers map", provider))
}
return serversObject
}

View File

@@ -2,11 +2,14 @@
package storage
import (
"sync"
"github.com/qdm12/gluetun/internal/models"
)
type Storage struct {
mergedServers models.AllServers
mergedMutex sync.RWMutex
// this is stored in memory to avoid re-parsing
// the embedded JSON file on every call to the
// SyncServers method.

View File

@@ -29,6 +29,9 @@ func (s *Storage) syncServers() (err error) {
hardcodedCount := countServers(s.hardcodedServers)
countOnFile := countServers(serversOnFile)
s.mergedMutex.Lock()
defer s.mergedMutex.Unlock()
if countOnFile == 0 {
s.logger.Info(fmt.Sprintf(
"creating %s with %d hardcoded servers",
@@ -47,7 +50,8 @@ func (s *Storage) syncServers() (err error) {
return nil
}
if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
err = s.flushToFile(s.filepath)
if err != nil {
return fmt.Errorf("cannot write servers to file: %w", err)
}
return nil