diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index e0988c8e..8f43a154 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -122,9 +122,8 @@ func _main(background context.Context, buildInfo models.BuildInformation, } // TODO run this in a loop or in openvpn to reload from file without restarting - storage := storage.New(logger, os) - const updateServerFile = true - allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile) + storage := storage.New(logger, os, constants.ServersData) + allServers, err := storage.SyncServers(constants.GetAllServers()) if err != nil { logger.Error(err) return 1 diff --git a/internal/cli/cli.go b/internal/cli/cli.go index f31d3ce2..6fb90b3a 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -70,7 +70,8 @@ func OpenvpnConfig(os os.OS) error { if err != nil { return err } - allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false) + allServers, err := storage.New(logger, os, constants.ServersData). + SyncServers(constants.GetAllServers()) if err != nil { return err } @@ -121,9 +122,8 @@ func Update(args []string, os os.OS) error { ctx := context.Background() const clientTimeout = 10 * time.Second httpClient := &http.Client{Timeout: clientTimeout} - storage := storage.New(logger, os) - const writeSync = false - currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync) + storage := storage.New(logger, os, constants.ServersData) + currentServers, err := storage.SyncServers(constants.GetAllServers()) if err != nil { return fmt.Errorf("cannot update servers: %w", err) } diff --git a/internal/constants/paths.go b/internal/constants/paths.go index 135a5b97..9f985906 100644 --- a/internal/constants/paths.go +++ b/internal/constants/paths.go @@ -29,4 +29,6 @@ const ( ClientKey models.Filepath = "/gluetun/client.key" // Client certificate filepath, used by Cyberghost. ClientCertificate models.Filepath = "/gluetun/client.crt" + // Servers information filepath. + ServersData = "/gluetun/servers.json" ) diff --git a/internal/storage/merge.go b/internal/storage/merge.go index 928e5770..8d84d2fc 100644 --- a/internal/storage/merge.go +++ b/internal/storage/merge.go @@ -14,80 +14,119 @@ func getUnixTimeDifference(unix1, unix2 int64) (difference time.Duration) { return difference.Truncate(time.Second) } -func (s *storage) mergeServers(hardcoded, persistent models.AllServers) (merged models.AllServers) { - merged.Version = hardcoded.Version - merged.Cyberghost = hardcoded.Cyberghost - if persistent.Cyberghost.Timestamp > hardcoded.Cyberghost.Timestamp { - s.logger.Info("Using Cyberghost servers from file (%s more recent)", - getUnixTimeDifference(persistent.Cyberghost.Timestamp, hardcoded.Cyberghost.Timestamp)) - merged.Cyberghost = persistent.Cyberghost +func (s *storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers { + return models.AllServers{ + Version: hardcoded.Version, + Cyberghost: s.mergeCyberghost(hardcoded.Cyberghost, persisted.Cyberghost), + Mullvad: s.mergeMullvad(hardcoded.Mullvad, persisted.Mullvad), + Nordvpn: s.mergeNordVPN(hardcoded.Nordvpn, persisted.Nordvpn), + Pia: s.mergePIA(hardcoded.Pia, persisted.Pia), + Privado: s.mergePrivado(hardcoded.Privado, persisted.Privado), + Purevpn: s.mergePureVPN(hardcoded.Purevpn, persisted.Purevpn), + Surfshark: s.mergeSurfshark(hardcoded.Surfshark, persisted.Surfshark), + Vyprvpn: s.mergeVyprvpn(hardcoded.Vyprvpn, persisted.Vyprvpn), + Windscribe: s.mergeWindscribe(hardcoded.Windscribe, persisted.Windscribe), } - merged.Mullvad = hardcoded.Mullvad - if persistent.Mullvad.Timestamp > hardcoded.Mullvad.Timestamp { - s.logger.Info("Using Mullvad servers from file (%s more recent)", - getUnixTimeDifference(persistent.Mullvad.Timestamp, hardcoded.Mullvad.Timestamp)) - merged.Mullvad = persistent.Mullvad - } - merged.Nordvpn = hardcoded.Nordvpn - if persistent.Nordvpn.Timestamp > hardcoded.Nordvpn.Timestamp { - s.logger.Info("Using Nordvpn servers from file (%s more recent)", - getUnixTimeDifference(persistent.Nordvpn.Timestamp, hardcoded.Nordvpn.Timestamp)) - merged.Nordvpn = persistent.Nordvpn - } - merged.Pia = hardcoded.Pia - if persistent.Pia.Timestamp > hardcoded.Pia.Timestamp { - versionDiff := hardcoded.Pia.Version - persistent.Pia.Version - if versionDiff > 0 { - s.logger.Info("Private Internet Access servers from file discarded because they are %d versions behind", - versionDiff) - merged.Pia = hardcoded.Pia - } else { - s.logger.Info("Using Private Internet Access servers from file (%s more recent)", - getUnixTimeDifference(persistent.Pia.Timestamp, hardcoded.Pia.Timestamp)) - merged.Pia = persistent.Pia - } - } - - merged.Privado = hardcoded.Privado - versionDiff := int(persistent.Privado.Version) - int(hardcoded.Privado.Version) - switch { - case versionDiff > 0: - s.logger.Info("Using Privado servers from file (%d version(s) more recent)", versionDiff) - merged.Privado = persistent.Privado - case persistent.Privado.Timestamp > hardcoded.Privado.Timestamp: - s.logger.Info("Using Privado servers from file (%s more recent)", - getUnixTimeDifference(persistent.Privado.Timestamp, hardcoded.Privado.Timestamp)) - merged.Privado = persistent.Privado - } - - merged.Purevpn = hardcoded.Purevpn - if persistent.Purevpn.Timestamp > hardcoded.Purevpn.Timestamp { - s.logger.Info("Using Purevpn servers from file (%s more recent)", - getUnixTimeDifference(persistent.Purevpn.Timestamp, hardcoded.Purevpn.Timestamp)) - merged.Purevpn = persistent.Purevpn - } - merged.Surfshark = hardcoded.Surfshark - if persistent.Surfshark.Timestamp > hardcoded.Surfshark.Timestamp { - s.logger.Info("Using Surfshark servers from file (%s more recent)", - getUnixTimeDifference(persistent.Surfshark.Timestamp, hardcoded.Surfshark.Timestamp)) - merged.Surfshark = persistent.Surfshark - } - merged.Vyprvpn = hardcoded.Vyprvpn - if persistent.Vyprvpn.Timestamp > hardcoded.Vyprvpn.Timestamp { - s.logger.Info("Using Vyprvpn servers from file (%s more recent)", - getUnixTimeDifference(persistent.Vyprvpn.Timestamp, hardcoded.Vyprvpn.Timestamp)) - merged.Vyprvpn = persistent.Vyprvpn - } - merged.Windscribe = hardcoded.Windscribe - if persistent.Windscribe.Timestamp > hardcoded.Windscribe.Timestamp { - if hardcoded.Windscribe.Version == 2 && persistent.Windscribe.Version == 1 { - s.logger.Info("Windscribe servers from file discarded because they are one version behind") - merged.Windscribe = hardcoded.Windscribe - } else { - s.logger.Info("Using Windscribe servers from file (%s more recent)", - getUnixTimeDifference(persistent.Windscribe.Timestamp, hardcoded.Windscribe.Timestamp)) - merged.Windscribe = persistent.Windscribe - } - } - return merged +} + +func (s *storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers) models.CyberghostServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + s.logger.Info("Using Cyberghost servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergeMullvad(hardcoded, persisted models.MullvadServers) models.MullvadServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + s.logger.Info("Using Mullvad servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) models.NordvpnServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + s.logger.Info("Using NordVPN servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + versionDiff := hardcoded.Version - persisted.Version + if versionDiff > 0 { + s.logger.Info( + "PIA servers from file discarded because they are %d versions behind", + versionDiff) + return hardcoded + } + s.logger.Info("Using PIA servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergePrivado(hardcoded, persisted models.PrivadoServers) models.PrivadoServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + versionDiff := hardcoded.Version - persisted.Version + if versionDiff > 0 { + s.logger.Info( + "Privado servers from file discarded because they are %d versions behind", + versionDiff) + return hardcoded + } + s.logger.Info("Using Privado servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) models.PurevpnServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + s.logger.Info("Using PureVPN servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) models.SurfsharkServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + s.logger.Info("Using Surfshark servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) models.VyprvpnServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + s.logger.Info("Using VyprVPN servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted +} + +func (s *storage) mergeWindscribe(hardcoded, persisted models.WindscribeServers) models.WindscribeServers { + if persisted.Timestamp <= hardcoded.Timestamp { + return hardcoded + } + versionDiff := hardcoded.Version - persisted.Version + if versionDiff > 0 { + s.logger.Info( + "Windscribe servers from file discarded because they are %d versions behind", + versionDiff) + return hardcoded + } + s.logger.Info("Using Windscribe servers from file (%s more recent)", + getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp)) + return persisted } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 35ac4854..b5e02081 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -7,18 +7,21 @@ import ( ) type Storage interface { - SyncServers(hardcodedServers models.AllServers, write bool) (allServers models.AllServers, err error) + // Passing an empty filepath disables writing to a file + SyncServers(hardcodedServers models.AllServers) (allServers models.AllServers, err error) FlushToFile(servers models.AllServers) error } type storage struct { - os os.OS - logger logging.Logger + os os.OS + logger logging.Logger + filepath string } -func New(logger logging.Logger, os os.OS) Storage { +func New(logger logging.Logger, os os.OS, filepath string) Storage { return &storage{ - os: os, - logger: logger.WithPrefix("storage: "), + os: os, + logger: logger.WithPrefix("storage: "), + filepath: filepath, } } diff --git a/internal/storage/sync.go b/internal/storage/sync.go index 1327a439..c463fd2f 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -2,6 +2,7 @@ package storage import ( "encoding/json" + "errors" "fmt" "reflect" @@ -9,8 +10,9 @@ import ( "github.com/qdm12/gluetun/internal/os" ) -const ( - jsonFilepath = "/gluetun/servers.json" +var ( + ErrCannotReadFile = errors.New("cannot read servers from file") + ErrCannotWriteFile = errors.New("cannot write servers to file") ) func countServers(allServers models.AllServers) int { @@ -25,38 +27,54 @@ func countServers(allServers models.AllServers) int { len(allServers.Windscribe.Servers) } -func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) ( +func (s *storage) SyncServers(hardcodedServers models.AllServers) ( allServers models.AllServers, err error) { - // Eventually read file - var serversOnFile models.AllServers - file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0) - if err != nil && !os.IsNotExist(err) { - return allServers, err - } - if err == nil { - var serversOnFile models.AllServers - decoder := json.NewDecoder(file) - if err := decoder.Decode(&serversOnFile); err != nil { - _ = file.Close() - return allServers, err - } - return allServers, file.Close() + serversOnFile, err := s.readFromFile(s.filepath) + if err != nil { + return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err) } - // Merge data from file and hardcoded - s.logger.Info("Merging by most recent %d hardcoded servers and %d servers read from %s", - countServers(hardcodedServers), countServers(serversOnFile), jsonFilepath) - allServers = s.mergeServers(hardcodedServers, serversOnFile) + hardcodedCount := countServers(hardcodedServers) + countOnFile := countServers(serversOnFile) + + if countOnFile == 0 { + s.logger.Info("creating %s with %d hardcoded servers", s.filepath, hardcodedCount) + allServers = hardcodedServers + } else { + s.logger.Info( + "merging by most recent %d hardcoded servers and %d servers read from %s", + hardcodedCount, countOnFile, s.filepath) + allServers = s.mergeServers(hardcodedServers, serversOnFile) + } // Eventually write file - if !write || reflect.DeepEqual(serversOnFile, allServers) { + if s.filepath == "" || reflect.DeepEqual(serversOnFile, allServers) { return allServers, nil } - return allServers, s.FlushToFile(allServers) + + if err := s.FlushToFile(allServers); err != nil { + return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err) + } + return allServers, nil +} + +func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) { + file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0) + if os.IsNotExist(err) { + return servers, nil + } else if err != nil { + return servers, err + } + decoder := json.NewDecoder(file) + if err := decoder.Decode(&servers); err != nil { + _ = file.Close() + return servers, err + } + return servers, file.Close() } func (s *storage) FlushToFile(servers models.AllServers) error { - file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return err } @@ -64,7 +82,7 @@ func (s *storage) FlushToFile(servers models.AllServers) error { encoder.SetIndent("", " ") if err := encoder.Encode(servers); err != nil { _ = file.Close() - return fmt.Errorf("cannot write to file: %w", err) + return err } return file.Close() }