chore(all): memory and thread safe storage
- settings: get filter choices from storage for settings validation - updater: update servers to the storage - storage: minimal deep copying and data duplication - storage: add merged servers mutex for thread safety - connection: filter servers in storage - formatter: format servers to Markdown in storage - PIA: get server by name from storage directly - Updater: get servers count from storage directly - Updater: equality check done in storage, fix #882
This commit is contained in:
27
internal/storage/choices.go
Normal file
27
internal/storage/choices.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings/validation"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func (s *Storage) GetFilterChoices(provider string) models.FilterChoices {
|
||||
if provider == providers.Custom {
|
||||
return models.FilterChoices{}
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
servers := serversObject.Servers
|
||||
return models.FilterChoices{
|
||||
Countries: validation.ExtractCountries(servers),
|
||||
Regions: validation.ExtractRegions(servers),
|
||||
Cities: validation.ExtractCities(servers),
|
||||
ISPs: validation.ExtractISPs(servers),
|
||||
Names: validation.ExtractServerNames(servers),
|
||||
Hostnames: validation.ExtractHostnames(servers),
|
||||
}
|
||||
}
|
||||
32
internal/storage/copy.go
Normal file
32
internal/storage/copy.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
func copyServer(server models.Server) (serverCopy models.Server) {
|
||||
serverCopy = server
|
||||
serverCopy.IPs = copyIPs(server.IPs)
|
||||
return serverCopy
|
||||
}
|
||||
|
||||
func copyIPs(toCopy []net.IP) (copied []net.IP) {
|
||||
if toCopy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copied = make([]net.IP, len(toCopy))
|
||||
for i := range toCopy {
|
||||
copied[i] = copyIP(toCopy[i])
|
||||
}
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func copyIP(toCopy net.IP) (copied net.IP) {
|
||||
copied = make(net.IP, len(toCopy))
|
||||
copy(copied, toCopy)
|
||||
return copied
|
||||
}
|
||||
75
internal/storage/copy_test.go
Normal file
75
internal/storage/copy_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_copyServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := models.Server{
|
||||
Country: "a",
|
||||
IPs: []net.IP{{1, 2, 3, 4}},
|
||||
}
|
||||
|
||||
serverCopy := copyServer(server)
|
||||
|
||||
assert.Equal(t, server, serverCopy)
|
||||
// Check for mutation
|
||||
serverCopy.IPs[0][0] = 9
|
||||
assert.NotEqual(t, server, serverCopy)
|
||||
}
|
||||
|
||||
func Test_copyIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
toCopy []net.IP
|
||||
copied []net.IP
|
||||
}{
|
||||
"nil": {},
|
||||
"empty": {
|
||||
toCopy: []net.IP{},
|
||||
copied: []net.IP{},
|
||||
},
|
||||
"single IP": {
|
||||
toCopy: []net.IP{{1, 1, 1, 1}},
|
||||
copied: []net.IP{{1, 1, 1, 1}},
|
||||
},
|
||||
"two IPs": {
|
||||
toCopy: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
|
||||
copied: []net.IP{{1, 1, 1, 1}, {2, 2, 2, 2}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Reserver leading 9 for copy modifications below
|
||||
for _, ipToCopy := range testCase.toCopy {
|
||||
require.NotEqual(t, 9, ipToCopy[0])
|
||||
}
|
||||
|
||||
copied := copyIPs(testCase.toCopy)
|
||||
|
||||
assert.Equal(t, testCase.copied, copied)
|
||||
|
||||
if len(copied) > 0 {
|
||||
original := testCase.toCopy[0][0]
|
||||
testCase.toCopy[0][0] = 9
|
||||
assert.NotEqual(t, 9, copied[0][0])
|
||||
testCase.toCopy[0][0] = original
|
||||
|
||||
copied[0][0] = 9
|
||||
assert.NotEqual(t, 9, testCase.toCopy[0][0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
143
internal/storage/filter.go
Normal file
143
internal/storage/filter.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
// FilterServers filter servers for the given provider and according
|
||||
// to the given selection. The filtered servers are deep copied so they
|
||||
// are safe for mutation by the caller.
|
||||
func (s *Storage) FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error) {
|
||||
if provider == providers.Custom {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
allServers := serversObject.Servers
|
||||
|
||||
if len(allServers) == 0 {
|
||||
return nil, ErrNoServerFound
|
||||
}
|
||||
|
||||
for _, server := range allServers {
|
||||
if filterServer(server, selection) {
|
||||
continue
|
||||
}
|
||||
|
||||
server = copyServer(server)
|
||||
servers = append(servers, server)
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
return nil, noServerFoundError(selection)
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func filterServer(server models.Server,
|
||||
selection settings.ServerSelection) (filtered bool) {
|
||||
// Note each condition is split to make sure
|
||||
// we have full testing coverage.
|
||||
if server.VPN != selection.VPN {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByProtocol(selection, server.TCP, server.UDP) {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.MultiHopOnly && !server.MultiHop {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.FreeOnly && !server.Free {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.StreamOnly && !server.Stream {
|
||||
return true
|
||||
}
|
||||
|
||||
if *selection.OwnedOnly && !server.Owned {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.Country, selection.Countries) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.Region, selection.Regions) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.City, selection.Cities) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.ISP, selection.ISPs) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilitiesUint16(server.Number, selection.Numbers) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.ServerName, selection.Names) {
|
||||
return true
|
||||
}
|
||||
|
||||
if filterByPossibilities(server.Hostname, selection.Hostnames) {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO filter port forward server for PIA
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func filterByPossibilities(value string, possibilities []string) (filtered bool) {
|
||||
if len(possibilities) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, possibility := range possibilities {
|
||||
if strings.EqualFold(value, possibility) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO merge with filterByPossibilities with generics in Go 1.18.
|
||||
func filterByPossibilitiesUint16(value uint16, possibilities []uint16) (filtered bool) {
|
||||
if len(possibilities) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, possibility := range possibilities {
|
||||
if value == possibility {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func filterByProtocol(selection settings.ServerSelection,
|
||||
serverTCP, serverUDP bool) (filtered bool) {
|
||||
switch selection.VPN {
|
||||
case vpn.Wireguard:
|
||||
return !serverUDP
|
||||
default: // OpenVPN
|
||||
wantTCP := *selection.OpenVPN.TCP
|
||||
wantUDP := !wantTCP
|
||||
return (wantTCP && !serverTCP) || (wantUDP && !serverUDP)
|
||||
}
|
||||
}
|
||||
@@ -4,21 +4,20 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
var _ Flusher = (*Storage)(nil)
|
||||
// FlushToFile flushes the merged servers data to the file
|
||||
// specified by path, as indented JSON.
|
||||
func (s *Storage) FlushToFile(path string) error {
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
type Flusher interface {
|
||||
FlushToFile(allServers *models.AllServers) error
|
||||
return s.flushToFile(path)
|
||||
}
|
||||
|
||||
func (s *Storage) FlushToFile(allServers *models.AllServers) error {
|
||||
return flushToFile(s.filepath, allServers)
|
||||
}
|
||||
|
||||
func flushToFile(path string, servers *models.AllServers) error {
|
||||
// flushToFile flushes the merged servers data to the file
|
||||
// specified by path, as indented JSON. It is not thread-safe.
|
||||
func (s *Storage) flushToFile(path string) error {
|
||||
dirPath := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dirPath, 0644); err != nil {
|
||||
return err
|
||||
@@ -28,11 +27,15 @@ func flushToFile(path string, servers *models.AllServers) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(servers); err != nil {
|
||||
|
||||
err = encoder.Encode(&s.mergedServers)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
120
internal/storage/formatting.go
Normal file
120
internal/storage/formatting.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func commaJoin(slice []string) string {
|
||||
return strings.Join(slice, ", ")
|
||||
}
|
||||
|
||||
var ErrNoServerFound = errors.New("no server found")
|
||||
|
||||
func noServerFoundError(selection settings.ServerSelection) (err error) {
|
||||
var messageParts []string
|
||||
|
||||
messageParts = append(messageParts, "VPN "+selection.VPN)
|
||||
|
||||
protocol := constants.UDP
|
||||
if *selection.OpenVPN.TCP {
|
||||
protocol = constants.TCP
|
||||
}
|
||||
messageParts = append(messageParts, "protocol "+protocol)
|
||||
|
||||
switch len(selection.Countries) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "country " + selection.Countries[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "countries " + commaJoin(selection.Countries)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Regions) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "region " + selection.Regions[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "regions " + commaJoin(selection.Regions)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Cities) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "city " + selection.Cities[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "cities " + commaJoin(selection.Cities)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
if *selection.OwnedOnly {
|
||||
messageParts = append(messageParts, "owned servers only")
|
||||
}
|
||||
|
||||
switch len(selection.ISPs) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "ISP " + selection.ISPs[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "ISPs " + commaJoin(selection.ISPs)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Hostnames) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "hostname " + selection.Hostnames[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "hostnames " + commaJoin(selection.Hostnames)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Names) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "name " + selection.Names[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "names " + commaJoin(selection.Names)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Numbers) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "server number " + strconv.Itoa(int(selection.Numbers[0]))
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
serverNumbers := make([]string, len(selection.Numbers))
|
||||
for i := range selection.Numbers {
|
||||
serverNumbers[i] = strconv.Itoa(int(selection.Numbers[i]))
|
||||
}
|
||||
part := "server numbers " + commaJoin(serverNumbers)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
if *selection.OpenVPN.PIAEncPreset != "" {
|
||||
part := "encryption preset " + *selection.OpenVPN.PIAEncPreset
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
if *selection.FreeOnly {
|
||||
messageParts = append(messageParts, "free tier only")
|
||||
}
|
||||
|
||||
message := "for " + strings.Join(messageParts, "; ")
|
||||
|
||||
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
|
||||
}
|
||||
@@ -1,7 +1,118 @@
|
||||
package storage
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/models"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
func (s *Storage) GetServers() models.AllServers {
|
||||
return s.mergedServers.GetCopy()
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
// SetServers sets the given servers for the given provider
|
||||
// in the storage in-memory map and saves all the servers
|
||||
// to file.
|
||||
// Note the servers given are not copied so the caller must
|
||||
// NOT MUTATE them after calling this method.
|
||||
func (s *Storage) SetServers(provider string, servers []models.Server) (err error) {
|
||||
if provider == providers.Custom {
|
||||
return
|
||||
}
|
||||
|
||||
s.mergedMutex.Lock()
|
||||
defer s.mergedMutex.Unlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
serversObject.Timestamp = time.Now().Unix()
|
||||
serversObject.Servers = servers
|
||||
s.mergedServers.ProviderToServers[provider] = serversObject
|
||||
|
||||
err = s.flushToFile(s.filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot save servers to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerByName returns the server for the given provider
|
||||
// and server name. It returns `ok` as false if the server is
|
||||
// not found. The returned server is also deep copied so it is
|
||||
// safe for mutation and/or thread safe use.
|
||||
func (s *Storage) GetServerByName(provider, name string) (
|
||||
server models.Server, ok bool) {
|
||||
if provider == providers.Custom {
|
||||
return server, false
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
for _, server := range serversObject.Servers {
|
||||
if server.ServerName == name {
|
||||
return copyServer(server), true
|
||||
}
|
||||
}
|
||||
|
||||
return server, false
|
||||
}
|
||||
|
||||
// GetServersCount returns the number of servers for the provider given.
|
||||
func (s *Storage) GetServersCount(provider string) (count int) {
|
||||
if provider == providers.Custom {
|
||||
return 0
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
return len(serversObject.Servers)
|
||||
}
|
||||
|
||||
// FormatToMarkdown Markdown formats the servers for the provider given
|
||||
// and returns the resulting string.
|
||||
func (s *Storage) FormatToMarkdown(provider string) (formatted string) {
|
||||
if provider == providers.Custom {
|
||||
return ""
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
formatted = serversObject.ToMarkdown(provider)
|
||||
return formatted
|
||||
}
|
||||
|
||||
// GetServersCount returns the number of servers for the provider given.
|
||||
func (s *Storage) ServersAreEqual(provider string, servers []models.Server) (equal bool) {
|
||||
if provider == providers.Custom {
|
||||
return true
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
existingServers := serversObject.Servers
|
||||
|
||||
if len(existingServers) != len(servers) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range existingServers {
|
||||
if !existingServers[i].Equal(servers[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Storage) getMergedServersObject(provider string) (serversObject models.Servers) {
|
||||
serversObject, ok := s.mergedServers.ProviderToServers[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in in-memory servers map", provider))
|
||||
}
|
||||
return serversObject
|
||||
}
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
mergedServers models.AllServers
|
||||
mergedMutex sync.RWMutex
|
||||
// this is stored in memory to avoid re-parsing
|
||||
// the embedded JSON file on every call to the
|
||||
// SyncServers method.
|
||||
|
||||
@@ -29,6 +29,9 @@ func (s *Storage) syncServers() (err error) {
|
||||
hardcodedCount := countServers(s.hardcodedServers)
|
||||
countOnFile := countServers(serversOnFile)
|
||||
|
||||
s.mergedMutex.Lock()
|
||||
defer s.mergedMutex.Unlock()
|
||||
|
||||
if countOnFile == 0 {
|
||||
s.logger.Info(fmt.Sprintf(
|
||||
"creating %s with %d hardcoded servers",
|
||||
@@ -47,7 +50,8 @@ func (s *Storage) syncServers() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := flushToFile(s.filepath, &s.mergedServers); err != nil {
|
||||
err = s.flushToFile(s.filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write servers to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user