package cli import ( "context" "encoding/json" "errors" "flag" "fmt" "net" "net/http" "os" "strings" "time" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/updater" ) var ( ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified") ErrDNSAddress = errors.New("DNS address is not valid") ErrNoProviderSpecified = errors.New("no provider was specified") ErrNewStorage = errors.New("cannot create storage") ErrUpdateServerInformation = errors.New("cannot update server information") ErrWriteToFile = errors.New("cannot write updated information to file") ) type Updater interface { Update(ctx context.Context, args []string, logger UpdaterLogger) error } type UpdaterLogger interface { Info(s string) Warn(s string) Error(s string) } func boolPtr(b bool) *bool { return &b } func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error { options := settings.Updater{CLI: boolPtr(true)} var endUserMode, maintainerMode, updateAll bool var dnsAddress, csvProviders string flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)") flagSet.BoolVar(&maintainerMode, "maintainer", false, "Write results to ./internal/storage/servers.json to modify the program (for maintainers)") flagSet.StringVar(&dnsAddress, "dns", "8.8.8.8", "DNS resolver address to use") flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers") flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for") if err := flagSet.Parse(args); err != nil { return err } if !endUserMode && !maintainerMode { return ErrModeUnspecified } options.DNSAddress = net.ParseIP(dnsAddress) if options.DNSAddress == nil { return fmt.Errorf("%w: %s", ErrDNSAddress, dnsAddress) } if updateAll { options.Providers = constants.AllProviders() } else { if csvProviders == "" { return ErrNoProviderSpecified } options.Providers = strings.Split(csvProviders, ",") } options.SetDefaults() err := options.Validate() if err != nil { return fmt.Errorf("options validation failed: %w", err) } const clientTimeout = 10 * time.Second httpClient := &http.Client{Timeout: clientTimeout} storage, err := storage.New(logger, constants.ServersData) if err != nil { return fmt.Errorf("%w: %s", ErrNewStorage, err) } currentServers := storage.GetServers() updater := updater.New(options, httpClient, currentServers, logger) allServers, err := updater.UpdateServers(ctx) if err != nil { return fmt.Errorf("%w: %s", ErrUpdateServerInformation, err) } if endUserMode { if err := storage.FlushToFile(allServers); err != nil { return fmt.Errorf("%w: %s", ErrWriteToFile, err) } } if maintainerMode { if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil { return fmt.Errorf("%w: %s", ErrWriteToFile, err) } } return nil } func writeToEmbeddedJSON(repoServersPath string, allServers models.AllServers) error { const perms = 0600 f, err := os.OpenFile(repoServersPath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms) if err != nil { return err } defer f.Close() encoder := json.NewEncoder(f) encoder.SetIndent("", " ") return encoder.Encode(allServers) }