From 21f4cf7ab5c55d54bbd73d24f81833187abcfff2 Mon Sep 17 00:00:00 2001 From: "Quentin McGaw (desktop)" Date: Fri, 23 Jul 2021 16:06:19 +0000 Subject: [PATCH] Maint: do not mock os functions - Use filepaths with /tmp for tests instead - Only mock functions where filepath can't be specified such as user.Lookup --- cmd/gluetun/main.go | 46 +++++++++---------- go.mod | 2 +- go.sum | 4 +- internal/alpine/alpine.go | 18 ++++---- internal/alpine/users.go | 6 +-- internal/alpine/version.go | 2 +- internal/cli/cli.go | 17 ++++--- internal/cli/clientkey.go | 6 +-- internal/cli/healthcheck.go | 5 +- internal/cli/openvpnconfig.go | 7 ++- internal/cli/update.go | 16 +++---- internal/configuration/health.go | 5 +- internal/configuration/reader.go | 5 +- internal/configuration/secrets.go | 12 ++--- internal/configuration/settings.go | 5 +- internal/dns/loop.go | 18 ++++---- internal/firewall/enable.go | 2 +- internal/firewall/firewall.go | 9 ++-- internal/firewall/iptables.go | 2 +- internal/openvpn/auth.go | 6 +-- internal/openvpn/custom.go | 8 ++-- internal/openvpn/loop.go | 23 +++++----- internal/openvpn/openvpn.go | 22 +++++---- internal/openvpn/tun.go | 22 ++++----- internal/provider/cyberghost/portforward.go | 5 +- internal/provider/fastestvpn/portforward.go | 5 +- internal/provider/hidemyass/portforward.go | 5 +- internal/provider/ipvanish/portforward.go | 5 +- internal/provider/ivpn/portforward.go | 5 +- internal/provider/mullvad/portforward.go | 5 +- internal/provider/nordvpn/portforward.go | 5 +- internal/provider/privado/portforward.go | 5 +- .../privateinternetaccess/portforward.go | 46 ++++++++++--------- .../privateinternetaccess/provider.go | 15 ++++-- internal/provider/privatevpn/portforward.go | 5 +- internal/provider/protonvpn/portforward.go | 5 +- internal/provider/provider.go | 3 +- internal/provider/purevpn/portforward.go | 5 +- internal/provider/surfshark/portforward.go | 5 +- internal/provider/torguard/portforward.go | 5 +- internal/provider/vpnunlimited/portforward.go | 5 +- internal/provider/vyprvpn/portforward.go | 7 ++- internal/provider/windscribe/portforward.go | 7 ++- internal/publicip/fs.go | 12 ++--- internal/publicip/loop.go | 11 ++--- internal/storage/flush.go | 7 +++ internal/storage/storage.go | 5 +- internal/storage/sync.go | 18 ++++---- 48 files changed, 226 insertions(+), 243 deletions(-) create mode 100644 internal/storage/flush.go diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 343f03ea..ec6cc177 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/http" - nativeos "os" + "os" "os/signal" "strconv" "strings" @@ -33,8 +33,6 @@ import ( "github.com/qdm12/gluetun/internal/updater" versionpkg "github.com/qdm12/gluetun/internal/version" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" - "github.com/qdm12/golibs/os/user" "github.com/qdm12/golibs/params" "github.com/qdm12/goshutdown" "github.com/qdm12/gosplash" @@ -61,21 +59,19 @@ func main() { } ctx := context.Background() - ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, nativeos.Interrupt) + ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) ctx, cancel := context.WithCancel(ctx) logger := logging.NewParent(logging.Settings{}) - args := nativeos.Args - os := os.New() - osUser := user.New() + args := os.Args unix := unix.New() cli := cli.New() env := params.NewEnv() errorCh := make(chan error) go func() { - errorCh <- _main(ctx, buildInfo, args, logger, env, os, osUser, unix, cli) + errorCh <- _main(ctx, buildInfo, args, logger, env, unix, cli) }() select { @@ -86,7 +82,7 @@ func main() { stop() close(errorCh) if err == nil { // expected exit such as healthcheck - nativeos.Exit(0) + os.Exit(0) } logger.Error(err) cancel() @@ -104,7 +100,7 @@ func main() { logger.Warn("Shutdown timed out") } - nativeos.Exit(1) + os.Exit(1) } var ( @@ -113,18 +109,18 @@ var ( //nolint:gocognit,gocyclo func _main(ctx context.Context, buildInfo models.BuildInformation, - args []string, logger logging.ParentLogger, env params.Env, os os.OS, - osUser user.OSUser, unix unix.Unix, cli cli.CLI) error { + args []string, logger logging.ParentLogger, env params.Env, + unix unix.Unix, cli cli.CLI) error { if len(args) > 1 { // cli operation switch args[1] { case "healthcheck": - return cli.HealthCheck(ctx, env, os, logger) + return cli.HealthCheck(ctx, env, logger) case "clientkey": - return cli.ClientKey(args[2:], os.OpenFile) + return cli.ClientKey(args[2:]) case "openvpnconfig": - return cli.OpenvpnConfig(os, logger) + return cli.OpenvpnConfig(logger) case "update": - return cli.Update(ctx, args[2:], os, logger) + return cli.Update(ctx, args[2:], logger) default: return fmt.Errorf("%w: %s", errCommandUnknown, args[1]) } @@ -133,19 +129,19 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, const clientTimeout = 15 * time.Second httpClient := &http.Client{Timeout: clientTimeout} // Create configurators - alpineConf := alpine.NewConfigurator(os.OpenFile, osUser) + alpineConf := alpine.NewConfigurator() ovpnConf := openvpn.NewConfigurator( logger.NewChild(logging.Settings{Prefix: "openvpn configurator: "}), - os, unix) + unix) dnsCrypto := dnscrypto.New(httpClient, "", "") const cacertsPath = "/etc/ssl/certs/ca-certificates.crt" - dnsConf := unbound.NewConfigurator(nil, os.OpenFile, dnsCrypto, + dnsConf := unbound.NewConfigurator(nil, dnsCrypto, "/etc/unbound", "/usr/sbin/unbound", cacertsPath) routingConf := routing.NewRouting( logger.NewChild(logging.Settings{Prefix: "routing: "})) firewallConf := firewall.NewConfigurator( logger.NewChild(logging.Settings{Prefix: "firewall: "}), - routingConf, os.OpenFile) + routingConf) announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z") if err != nil { @@ -179,7 +175,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } var allSettings configuration.Settings - err = allSettings.Read(env, os, + err = allSettings.Read(env, logger.NewChild(logging.Settings{Prefix: "configuration: "})) if err != nil { return err @@ -196,7 +192,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, // TODO run this in a loop or in openvpn to reload from file without restarting storage := storage.New( logger.NewChild(logging.Settings{Prefix: "storage: "}), - os, constants.ServersData) + constants.ServersData) allServers, err := storage.SyncServers(constants.GetAllServers()) if err != nil { return err @@ -314,7 +310,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings) openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, - ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh) + ovpnConf, firewallConf, routingConf, logger, httpClient, tunnelReadyCh) openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) // wait for restartOpenvpn @@ -331,7 +327,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, unboundLogger := logger.NewChild(logging.Settings{Prefix: "dns over tls: "}) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient, - unboundLogger, os.OpenFile) + unboundLogger) dnsHandler, dnsCtx, dnsDone := goshutdown.NewGoRoutineHandler( "unbound", defaultGoRoutineSettings) // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker @@ -340,7 +336,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, publicIPLooper := publicip.NewLooper(httpClient, logger.NewChild(logging.Settings{Prefix: "ip getter: "}), - allSettings.PublicIP, puid, pgid, os) + allSettings.PublicIP, puid, pgid) pubIPHandler, pubIPCtx, pubIPDone := goshutdown.NewGoRoutineHandler( "public IP", defaultGoRoutineSettings) go publicIPLooper.Run(pubIPCtx, pubIPDone) diff --git a/go.mod b/go.mod index 6d5ba049..a5e64c89 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.16 require ( github.com/fatih/color v1.12.0 github.com/golang/mock v1.6.0 - github.com/qdm12/dns v1.9.0 + github.com/qdm12/dns v1.10.0 github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 github.com/qdm12/goshutdown v0.1.0 github.com/qdm12/gosplash v0.1.0 diff --git a/go.sum b/go.sum index 53b41ec6..1b93d85e 100644 --- a/go.sum +++ b/go.sum @@ -63,8 +63,8 @@ github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMg github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/qdm12/dns v1.9.0 h1:p4g/BfbpQ+gJRpQdklDAnybkjds+OuenF0wEGoZ8/AI= -github.com/qdm12/dns v1.9.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg= +github.com/qdm12/dns v1.10.0 h1:WX5QQ5+2h34xfhfxJTmvyURbs9XE4qNrEGtyNeq38Bw= +github.com/qdm12/dns v1.10.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg= github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg= github.com/qdm12/golibs v0.0.0-20210716185557-66793f4ddd80/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg= github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 h1:2OKHAR0SK8BtTtWCRNoSn58eh+iVDA3Cwq4i2CnD3i4= diff --git a/internal/alpine/alpine.go b/internal/alpine/alpine.go index c3d575f3..003ce8f5 100644 --- a/internal/alpine/alpine.go +++ b/internal/alpine/alpine.go @@ -3,9 +3,7 @@ package alpine import ( "context" - - "github.com/qdm12/golibs/os" - "github.com/qdm12/golibs/os/user" + "os/user" ) type Configurator interface { @@ -14,13 +12,17 @@ type Configurator interface { } type configurator struct { - openFile os.OpenFileFunc - osUser user.OSUser + alpineReleasePath string + passwdPath string + lookupID func(uid string) (*user.User, error) + lookup func(username string) (*user.User, error) } -func NewConfigurator(openFile os.OpenFileFunc, osUser user.OSUser) Configurator { +func NewConfigurator() Configurator { return &configurator{ - openFile: openFile, - osUser: osUser, + alpineReleasePath: "/etc/alpine-release", + passwdPath: "/etc/passwd", + lookupID: user.LookupId, + lookup: user.Lookup, } } diff --git a/internal/alpine/users.go b/internal/alpine/users.go index d75db15c..0a9e7e6e 100644 --- a/internal/alpine/users.go +++ b/internal/alpine/users.go @@ -15,7 +15,7 @@ var ( // CreateUser creates a user in Alpine with the given UID. func (c *configurator) CreateUser(username string, uid int) (createdUsername string, err error) { UIDStr := strconv.Itoa(uid) - u, err := c.osUser.LookupID(UIDStr) + u, err := c.lookupID(UIDStr) _, unknownUID := err.(user.UnknownUserIdError) if err != nil && !unknownUID { return "", err @@ -28,7 +28,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str return u.Username, nil } - u, err = c.osUser.Lookup(username) + u, err = c.lookup(username) _, unknownUsername := err.(user.UnknownUserError) if err != nil && !unknownUsername { return "", err @@ -39,7 +39,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str ErrUserAlreadyExists, username, u.Uid, uid) } - file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644) + file, err := os.OpenFile(c.passwdPath, os.O_APPEND|os.O_WRONLY, 0644) if err != nil { return "", err } diff --git a/internal/alpine/version.go b/internal/alpine/version.go index ee4a96ce..786d3ba4 100644 --- a/internal/alpine/version.go +++ b/internal/alpine/version.go @@ -8,7 +8,7 @@ import ( ) func (c *configurator) Version(ctx context.Context) (version string, err error) { - file, err := c.openFile("/etc/alpine-release", os.O_RDONLY, 0) + file, err := os.OpenFile(c.alpineReleasePath, os.O_RDONLY, 0) if err != nil { return "", err } diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 1664f459..8507dbd7 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -5,19 +5,22 @@ import ( "context" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" ) type CLI interface { - ClientKey(args []string, openFile os.OpenFileFunc) error - HealthCheck(ctx context.Context, env params.Env, os os.OS, logger logging.Logger) error - OpenvpnConfig(os os.OS, logger logging.Logger) error - Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error + ClientKey(args []string) error + HealthCheck(ctx context.Context, env params.Env, logger logging.Logger) error + OpenvpnConfig(logger logging.Logger) error + Update(ctx context.Context, args []string, logger logging.Logger) error } -type cli struct{} +type cli struct { + repoServersPath string +} func New() CLI { - return &cli{} + return &cli{ + repoServersPath: "./internal/constants/servers.json", + } } diff --git a/internal/cli/clientkey.go b/internal/cli/clientkey.go index f4a88714..68ee51f7 100644 --- a/internal/cli/clientkey.go +++ b/internal/cli/clientkey.go @@ -4,19 +4,19 @@ import ( "flag" "fmt" "io" + "os" "strings" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/os" ) -func (c *cli) ClientKey(args []string, openFile os.OpenFileFunc) error { +func (c *cli) ClientKey(args []string) error { flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError) filepath := flagSet.String("path", constants.ClientKey, "file path to the client.key file") if err := flagSet.Parse(args); err != nil { return err } - file, err := openFile(*filepath, os.O_RDONLY, 0) + file, err := os.OpenFile(*filepath, os.O_RDONLY, 0) if err != nil { return err } diff --git a/internal/cli/healthcheck.go b/internal/cli/healthcheck.go index 95543c9b..0e38e58d 100644 --- a/internal/cli/healthcheck.go +++ b/internal/cli/healthcheck.go @@ -9,15 +9,14 @@ import ( "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/healthcheck" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" ) func (c *cli) HealthCheck(ctx context.Context, env params.Env, - os os.OS, logger logging.Logger) error { + logger logging.Logger) error { // Extract the health server port from the configuration. config := configuration.Health{} - err := config.Read(env, os, logger) + err := config.Read(env, logger) if err != nil { return err } diff --git a/internal/cli/openvpnconfig.go b/internal/cli/openvpnconfig.go index 74c38c59..0ff7d41d 100644 --- a/internal/cli/openvpnconfig.go +++ b/internal/cli/openvpnconfig.go @@ -10,17 +10,16 @@ import ( "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" ) -func (c *cli) OpenvpnConfig(os os.OS, logger logging.Logger) error { +func (c *cli) OpenvpnConfig(logger logging.Logger) error { var allSettings configuration.Settings - err := allSettings.Read(params.NewEnv(), os, logger) + err := allSettings.Read(params.NewEnv(), logger) if err != nil { return err } - allServers, err := storage.New(logger, os, constants.ServersData). + allServers, err := storage.New(logger, constants.ServersData). SyncServers(constants.GetAllServers()) if err != nil { return err diff --git a/internal/cli/update.go b/internal/cli/update.go index fb0a5648..f3aaf51e 100644 --- a/internal/cli/update.go +++ b/internal/cli/update.go @@ -7,7 +7,7 @@ import ( "flag" "fmt" "net/http" - nativeos "os" + "os" "time" "github.com/qdm12/gluetun/internal/configuration" @@ -16,7 +16,6 @@ import ( "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) var ( @@ -26,7 +25,7 @@ var ( ErrWriteToFile = errors.New("cannot write updated information to file") ) -func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error { +func (c *cli) Update(ctx context.Context, args []string, logger logging.Logger) error { options := configuration.Updater{CLI: true} var endUserMode, maintainerMode, updateAll bool flagSet := flag.NewFlagSet("update", flag.ExitOnError) @@ -65,7 +64,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin const clientTimeout = 10 * time.Second httpClient := &http.Client{Timeout: clientTimeout} - storage := storage.New(logger, os, constants.ServersData) + storage := storage.New(logger, constants.ServersData) currentServers, err := storage.SyncServers(constants.GetAllServers()) if err != nil { return fmt.Errorf("%w: %s", ErrSyncServers, err) @@ -83,7 +82,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin } if maintainerMode { - if err := writeToEmbeddedJSON(os, allServers); err != nil { + if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil { return fmt.Errorf("%w: %s", ErrWriteToFile, err) } } @@ -91,10 +90,11 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin return nil } -func writeToEmbeddedJSON(os os.OS, allServers models.AllServers) error { +func writeToEmbeddedJSON(repoServersPath string, + allServers models.AllServers) error { const perms = 0600 - f, err := os.OpenFile("./internal/constants/servers.json", - nativeos.O_TRUNC|nativeos.O_WRONLY|nativeos.O_CREATE, perms) + f, err := os.OpenFile(repoServersPath, + os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms) if err != nil { return err } diff --git a/internal/configuration/health.go b/internal/configuration/health.go index d54dbac0..e881aaa7 100644 --- a/internal/configuration/health.go +++ b/internal/configuration/health.go @@ -5,7 +5,6 @@ import ( "strings" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" ) @@ -33,8 +32,8 @@ func (settings *Health) lines() (lines []string) { } // Read is to be used for the healthcheck query mode. -func (settings *Health) Read(env params.Env, os os.OS, logger logging.Logger) (err error) { - reader := newReader(env, os, logger) +func (settings *Health) Read(env params.Env, logger logging.Logger) (err error) { + reader := newReader(env, logger) return settings.read(reader) } diff --git a/internal/configuration/reader.go b/internal/configuration/reader.go index 40e04b06..552a67fe 100644 --- a/internal/configuration/reader.go +++ b/internal/configuration/reader.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" "github.com/qdm12/golibs/verification" ) @@ -17,15 +16,13 @@ type reader struct { env params.Env logger logging.Logger regex verification.Regex - os os.OS } -func newReader(env params.Env, os os.OS, logger logging.Logger) reader { +func newReader(env params.Env, logger logging.Logger) reader { return reader{ env: env, logger: logger, regex: verification.NewRegex(), - os: os, } } diff --git a/internal/configuration/secrets.go b/internal/configuration/secrets.go index 84e63270..978a92ca 100644 --- a/internal/configuration/secrets.go +++ b/internal/configuration/secrets.go @@ -4,9 +4,9 @@ import ( "errors" "fmt" "io" + "os" "strings" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" ) @@ -48,7 +48,7 @@ func (r *reader) getFromEnvOrSecretFile(envKey string, compulsory bool, retroKey ErrGetSecretFilepath, secretFilepathEnvKey, err) } - file, fileErr := r.os.OpenFile(filepath, os.O_RDONLY, 0) + file, fileErr := os.OpenFile(filepath, os.O_RDONLY, 0) if os.IsNotExist(fileErr) { if compulsory { return "", envErr @@ -85,7 +85,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) ( return b, fmt.Errorf("environment variable %s: %w: %s", key, ErrGetSecretFilepath, err) } - b, err = readFromFile(r.os.OpenFile, secretFilepath) + b, err = readFromFile(secretFilepath) if err != nil && !os.IsNotExist(err) { return b, fmt.Errorf("%w: %s", ErrReadSecretFile, err) } else if err == nil { @@ -93,7 +93,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) ( } // Secret file does not exist, try the non secret file - b, err = readFromFile(r.os.OpenFile, filepath) + b, err = readFromFile(filepath) if err != nil && !os.IsNotExist(err) { return nil, fmt.Errorf("%w: %s", ErrReadSecretFile, err) } else if err == nil { @@ -102,8 +102,8 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) ( return nil, fmt.Errorf("%w: %s and %s", ErrFilesDoNotExist, secretFilepath, filepath) } -func readFromFile(openFile os.OpenFileFunc, filepath string) (b []byte, err error) { - file, err := openFile(filepath, os.O_RDONLY, 0) +func readFromFile(filepath string) (b []byte, err error) { + file, err := os.Open(filepath) if err != nil { return nil, err } diff --git a/internal/configuration/settings.go b/internal/configuration/settings.go index a8ec95db..15a01d4c 100644 --- a/internal/configuration/settings.go +++ b/internal/configuration/settings.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" "github.com/qdm12/golibs/params" ) @@ -62,8 +61,8 @@ var ( // Read obtains all configuration options for the program and returns an error as soon // as an error is encountered reading them. -func (settings *Settings) Read(env params.Env, os os.OS, logger logging.Logger) (err error) { - r := newReader(env, os, logger) +func (settings *Settings) Read(env params.Env, logger logging.Logger) (err error) { + r := newReader(env, logger) settings.VersionInformation, err = r.env.OnOff("VERSION_INFORMATION", params.Default("on")) if err != nil { diff --git a/internal/dns/loop.go b/internal/dns/loop.go index c880e5b1..e4e50e18 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -16,7 +16,6 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) type Looper interface { @@ -33,6 +32,7 @@ type Looper interface { type looper struct { state *state conf unbound.Configurator + resolvConf string blockBuilder blacklist.Builder client *http.Client logger logging.Logger @@ -45,13 +45,12 @@ type looper struct { backoffTime time.Duration timeNow func() time.Time timeSince func(time.Time) time.Duration - openFile os.OpenFileFunc } const defaultBackoffTime = 10 * time.Second func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *http.Client, - logger logging.Logger, openFile os.OpenFileFunc) Looper { + logger logging.Logger) Looper { start := make(chan struct{}) running := make(chan models.LoopStatus) stop := make(chan struct{}) @@ -63,6 +62,7 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht return &looper{ state: state, conf: conf, + resolvConf: "/etc/resolv.conf", blockBuilder: blacklist.NewBuilder(client), client: client, logger: logger, @@ -75,7 +75,6 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht backoffTime: defaultBackoffTime, timeNow: time.Now, timeSince: time.Since, - openFile: openFile, } } @@ -227,8 +226,8 @@ func (l *looper) setupUnbound(ctx context.Context) ( // use Unbound nameserver.UseDNSInternally(net.IP{127, 0, 0, 1}) - err = nameserver.UseDNSSystemWide(l.openFile, - net.IP{127, 0, 0, 1}, settings.KeepNameserver) + err = nameserver.UseDNSSystemWide(l.resolvConf, net.IP{127, 0, 0, 1}, + settings.KeepNameserver) if err != nil { l.logger.Error(err) } @@ -256,8 +255,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) { l.logger.Info("using plaintext DNS at address %s", targetIP) } nameserver.UseDNSInternally(targetIP) - if err := nameserver.UseDNSSystemWide(l.openFile, - targetIP, settings.KeepNameserver); err != nil { + err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver) + if err != nil { l.logger.Error(err) } return @@ -271,7 +270,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) { l.logger.Info("using plaintext DNS at address " + targetIP.String()) } nameserver.UseDNSInternally(targetIP) - if err := nameserver.UseDNSSystemWide(l.openFile, targetIP, settings.KeepNameserver); err != nil { + err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver) + if err != nil { l.logger.Error(err) } } diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 257cc24d..004bc295 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -136,7 +136,7 @@ func (c *configurator) enable(ctx context.Context) (err error) { } } - if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil { + if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil { return fmt.Errorf("%w: %s", ErrUserPostRules, err) } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 83b960b0..61cf6432 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -11,7 +11,6 @@ import ( "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) // Configurator allows to change firewall rules and modify network routes. @@ -32,7 +31,6 @@ type configurator struct { //nolint:maligned commander command.Commander logger logging.Logger routing routing.Routing - openFile os.OpenFileFunc // for custom iptables rules iptablesMutex sync.Mutex ip6tablesMutex sync.Mutex debug bool @@ -43,7 +41,8 @@ type configurator struct { //nolint:maligned networkInfoMutex sync.Mutex // Fixed state - ip6Tables bool + ip6Tables bool + customRulesPath string // State enabled bool @@ -54,15 +53,15 @@ type configurator struct { //nolint:maligned } // NewConfigurator creates a new Configurator instance. -func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator { +func NewConfigurator(logger logging.Logger, routing routing.Routing) Configurator { commander := command.NewCommander() return &configurator{ commander: commander, logger: logger, routing: routing, - openFile: openFile, allowedInputPorts: make(map[uint16]string), ip6Tables: ip6tablesSupported(context.Background(), commander), + customRulesPath: "/iptables/post-rules.txt", } } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 2eafa4bf..9dccbcfd 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -196,7 +196,7 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port } func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { - file, err := c.openFile(filepath, os.O_RDONLY, 0) + file, err := os.OpenFile(filepath, os.O_RDONLY, 0) if os.IsNotExist(err) { return nil } else if err != nil { diff --git a/internal/openvpn/auth.go b/internal/openvpn/auth.go index 24ee0e11..307b1ac6 100644 --- a/internal/openvpn/auth.go +++ b/internal/openvpn/auth.go @@ -10,14 +10,14 @@ import ( // WriteAuthFile writes the OpenVPN auth file to disk with the right permissions. func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) error { - file, err := c.os.OpenFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0) + file, err := os.Open(c.authFilePath) if err != nil && !os.IsNotExist(err) { return err } if os.IsNotExist(err) { - file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_WRONLY|os.O_CREATE, 0400) + file, err = os.OpenFile(c.authFilePath, os.O_WRONLY|os.O_CREATE, 0400) if err != nil { return err } @@ -49,7 +49,7 @@ func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) erro } c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf) - file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_TRUNC|os.O_WRONLY, 0400) + file, err = os.OpenFile(c.authFilePath, os.O_TRUNC|os.O_WRONLY, 0400) if err != nil { return err } diff --git a/internal/openvpn/custom.go b/internal/openvpn/custom.go index 0cfe5906..c7f02c2b 100644 --- a/internal/openvpn/custom.go +++ b/internal/openvpn/custom.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "os" "strconv" "strings" @@ -12,14 +13,13 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/utils" - "github.com/qdm12/golibs/os" ) var errProcessCustomConfig = errors.New("cannot process custom config") func (l *looper) processCustomConfig(settings configuration.OpenVPN) ( lines []string, connection models.OpenVPNConnection, err error) { - lines, err = readCustomConfigLines(settings.Config, l.openFile) + lines, err = readCustomConfigLines(settings.Config) if err != nil { return nil, connection, fmt.Errorf("%w: %s", errProcessCustomConfig, err) } @@ -35,9 +35,9 @@ func (l *looper) processCustomConfig(settings configuration.OpenVPN) ( return lines, connection, nil } -func readCustomConfigLines(filepath string, openFile os.OpenFileFunc) ( +func readCustomConfigLines(filepath string) ( lines []string, err error) { - file, err := openFile(filepath, os.O_RDONLY, 0) + file, err := os.Open(filepath) if err != nil { return nil, err } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 1e39d1bc..c9f6f5ff 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "os" "strings" "time" @@ -14,7 +15,6 @@ import ( "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) type Looper interface { @@ -34,9 +34,10 @@ type Looper interface { type looper struct { state *state // Fixed parameters - username string - puid int - pgid int + username string + puid int + pgid int + targetConfPath string // Configurators conf Configurator fw firewall.Configurator @@ -44,7 +45,6 @@ type looper struct { // Other objects logger, pfLogger logging.Logger client *http.Client - openFile os.OpenFileFunc tunnelReady chan<- struct{} // Internal channels and values stop <-chan struct{} @@ -64,7 +64,7 @@ const ( func NewLooper(settings configuration.OpenVPN, username string, puid, pgid int, allServers models.AllServers, conf Configurator, fw firewall.Configurator, routing routing.Routing, - logger logging.ParentLogger, client *http.Client, openFile os.OpenFileFunc, + logger logging.ParentLogger, client *http.Client, tunnelReady chan<- struct{}) Looper { start := make(chan struct{}) running := make(chan models.LoopStatus) @@ -79,13 +79,13 @@ func NewLooper(settings configuration.OpenVPN, username: username, puid: puid, pgid: pgid, + targetConfPath: constants.OpenVPNConf, conf: conf, fw: fw, routing: routing, logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}), pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}), client: client, - openFile: openFile, tunnelReady: tunnelReady, start: start, running: running, @@ -145,7 +145,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { } } - if err := writeOpenvpnConf(lines, l.openFile); err != nil { + if err := l.writeOpenvpnConf(lines); err != nil { l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue @@ -279,13 +279,12 @@ func (l *looper) portForward(ctx context.Context, defer l.state.settingsMu.RUnlock() return settings.Provider.PortForwarding.Filepath } - providerConf.PortForward(ctx, - client, l.openFile, l.pfLogger, + providerConf.PortForward(ctx, client, l.pfLogger, gateway, l.fw, syncState) } -func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error { - file, err := openFile(constants.OpenVPNConf, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) +func (l *looper) writeOpenvpnConf(lines []string) error { + file, err := os.OpenFile(l.targetConfPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { return err } diff --git a/internal/openvpn/openvpn.go b/internal/openvpn/openvpn.go index 9de40014..7389a72a 100644 --- a/internal/openvpn/openvpn.go +++ b/internal/openvpn/openvpn.go @@ -5,10 +5,10 @@ package openvpn import ( "context" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/unix" "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) type Configurator interface { @@ -22,17 +22,19 @@ type Configurator interface { } type configurator struct { - logger logging.Logger - commander command.Commander - os os.OS - unix unix.Unix + logger logging.Logger + commander command.Commander + unix unix.Unix + authFilePath string + tunDevPath string } -func NewConfigurator(logger logging.Logger, os os.OS, unix unix.Unix) Configurator { +func NewConfigurator(logger logging.Logger, unix unix.Unix) Configurator { return &configurator{ - logger: logger, - commander: command.NewCommander(), - os: os, - unix: unix, + logger: logger, + commander: command.NewCommander(), + unix: unix, + authFilePath: constants.OpenVPNAuthConf, + tunDevPath: constants.TunnelDevice, } } diff --git a/internal/openvpn/tun.go b/internal/openvpn/tun.go index 01eaf80c..95f2aa4f 100644 --- a/internal/openvpn/tun.go +++ b/internal/openvpn/tun.go @@ -3,15 +3,15 @@ package openvpn import ( "fmt" "os" + "path/filepath" - "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/unix" ) // CheckTUN checks the tunnel device is present and accessible. func (c *configurator) CheckTUN() error { - c.logger.Info("checking for device %s", constants.TunnelDevice) - f, err := c.os.OpenFile(constants.TunnelDevice, os.O_RDWR, 0) + c.logger.Info("checking for device " + c.tunDevPath) + f, err := os.OpenFile(c.tunDevPath, os.O_RDWR, 0) if err != nil { return fmt.Errorf("TUN device is not available: %w", err) } @@ -22,8 +22,10 @@ func (c *configurator) CheckTUN() error { } func (c *configurator) CreateTUN() error { - c.logger.Info("creating %s", constants.TunnelDevice) - if err := c.os.MkdirAll("/dev/net", 0751); err != nil { //nolint:gomnd + c.logger.Info("creating " + c.tunDevPath) + + parentDir := filepath.Dir(c.tunDevPath) + if err := os.MkdirAll(parentDir, 0751); err != nil { //nolint:gomnd return err } @@ -32,17 +34,13 @@ func (c *configurator) CreateTUN() error { minor = 200 ) dev := c.unix.Mkdev(major, minor) - if err := c.unix.Mknod(constants.TunnelDevice, unix.S_IFCHR, int(dev)); err != nil { + if err := c.unix.Mknod(c.tunDevPath, unix.S_IFCHR, int(dev)); err != nil { return err } - file, err := c.os.OpenFile(constants.TunnelDevice, os.O_WRONLY, 0666) //nolint:gomnd - if err != nil { - return err - } const readWriteAllPerms os.FileMode = 0666 - if err := file.Chmod(readWriteAllPerms); err != nil { - _ = file.Close() + file, err := os.OpenFile(c.tunDevPath, os.O_WRONLY, readWriteAllPerms) + if err != nil { return err } diff --git a/internal/provider/cyberghost/portforward.go b/internal/provider/cyberghost/portforward.go index 5ec4c4ba..299094a9 100644 --- a/internal/provider/cyberghost/portforward.go +++ b/internal/provider/cyberghost/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (c *Cyberghost) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Cyberghost") } diff --git a/internal/provider/fastestvpn/portforward.go b/internal/provider/fastestvpn/portforward.go index c4d6dddf..f3801478 100644 --- a/internal/provider/fastestvpn/portforward.go +++ b/internal/provider/fastestvpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (f *Fastestvpn) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for FastestVPN") } diff --git a/internal/provider/hidemyass/portforward.go b/internal/provider/hidemyass/portforward.go index d4b8c3f1..3fa3a6d0 100644 --- a/internal/provider/hidemyass/portforward.go +++ b/internal/provider/hidemyass/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (f *HideMyAss) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for HideMyAss") } diff --git a/internal/provider/ipvanish/portforward.go b/internal/provider/ipvanish/portforward.go index ddd3a822..47f486b1 100644 --- a/internal/provider/ipvanish/portforward.go +++ b/internal/provider/ipvanish/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (i *Ipvanish) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Ipvanish") } diff --git a/internal/provider/ivpn/portforward.go b/internal/provider/ivpn/portforward.go index 4f8389d6..01b1eae7 100644 --- a/internal/provider/ivpn/portforward.go +++ b/internal/provider/ivpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (i *Ivpn) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Ivpn") } diff --git a/internal/provider/mullvad/portforward.go b/internal/provider/mullvad/portforward.go index ebe3ce4f..9bb8d6a2 100644 --- a/internal/provider/mullvad/portforward.go +++ b/internal/provider/mullvad/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (m *Mullvad) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding logic is not needed for Mullvad") } diff --git a/internal/provider/nordvpn/portforward.go b/internal/provider/nordvpn/portforward.go index df46f828..9bafbea3 100644 --- a/internal/provider/nordvpn/portforward.go +++ b/internal/provider/nordvpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (n *Nordvpn) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for NordVPN") } diff --git a/internal/provider/privado/portforward.go b/internal/provider/privado/portforward.go index 3ad0065c..48c0dea0 100644 --- a/internal/provider/privado/portforward.go +++ b/internal/provider/privado/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (p *Privado) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Privado") } diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index ce4c22d1..d6da3cef 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "os" "strconv" "strings" "time" @@ -18,7 +19,6 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/format" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) var ( @@ -28,7 +28,7 @@ var ( // PortForward obtains a VPN server side port forwarded from PIA. //nolint:gocognit func (p *PIA) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator, + logger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { commonName := p.activeServer.ServerName if !p.activeServer.PortForward { @@ -47,7 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, return } - data, err := readPIAPortForwardData(openFile) + data, err := readPIAPortForwardData(p.portForwardPath) if err != nil { logger.Error(err) } @@ -67,7 +67,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, if !dataFound || expired { tryUntilSuccessful(ctx, logger, func() error { - data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile) + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, + p.portForwardPath, p.authFilePath) return err }) if ctx.Err() != nil { @@ -91,7 +92,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, filepath := syncState(data.Port) logger.Info("Writing port to " + filepath) - if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { + if err := writePortForwardedToFile(filepath, data.Port); err != nil { logger.Error(err) } @@ -128,7 +129,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, data.Expiration.Format(time.RFC1123) + ", getting another one") oldPort := data.Port for { - data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile) + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, + p.portForwardPath, p.authFilePath) if err != nil { logger.Error(err) continue @@ -146,7 +148,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, } filepath := syncState(data.Port) logger.Info("Writing port to " + filepath) - if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { + if err := writePortForwardedToFile(filepath, data.Port); err != nil { logger.Error("Cannot write port forward data to file: " + err.Error()) } if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { @@ -168,8 +170,8 @@ var ( ) func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client, - gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) { - data.Token, err = fetchToken(ctx, openFile, client) + gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) { + data.Token, err = fetchToken(ctx, client, authFilePath) if err != nil { return data, fmt.Errorf("%w: %s", ErrFetchToken, err) } @@ -179,7 +181,7 @@ func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *htt return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err) } - if err := writePIAPortForwardData(openFile, data); err != nil { + if err := writePIAPortForwardData(portForwardPath, data); err != nil { return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err) } @@ -199,8 +201,8 @@ type piaPortForwardData struct { Expiration time.Time `json:"expires_at"` } -func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) { - file, err := openFile(constants.PIAPortForward, os.O_RDONLY, 0) +func readPIAPortForwardData(portForwardPath string) (data piaPortForwardData, err error) { + file, err := os.Open(portForwardPath) if os.IsNotExist(err) { return data, nil } else if err != nil { @@ -216,8 +218,8 @@ func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, return data, file.Close() } -func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) { - file, err := openFile(constants.PIAPortForward, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) +func writePIAPortForwardData(portForwardPath string, data piaPortForwardData) (err error) { + file, err := os.OpenFile(portForwardPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { return err } @@ -269,9 +271,9 @@ var ( errEmptyToken = errors.New("token received is empty") ) -func fetchToken(ctx context.Context, openFile os.OpenFileFunc, - client *http.Client) (token string, err error) { - username, password, err := getOpenvpnCredentials(openFile) +func fetchToken(ctx context.Context, client *http.Client, + authFilePath string) (token string, err error) { + username, password, err := getOpenvpnCredentials(authFilePath) if err != nil { return "", fmt.Errorf("%w: %s", errGetCredentials, err) } @@ -321,8 +323,9 @@ var ( errAuthFileMalformed = errors.New("authentication file is malformed") ) -func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) { - file, err := openFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0) +func getOpenvpnCredentials(authFilePath string) ( + username, password string, err error) { + file, err := os.Open(authFilePath) if err != nil { return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err) } @@ -460,9 +463,8 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia return nil } -func writePortForwardedToFile(openFile os.OpenFileFunc, - filepath string, port uint16) (err error) { - file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) +func writePortForwardedToFile(filepath string, port uint16) (err error) { + file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { return err } diff --git a/internal/provider/privateinternetaccess/provider.go b/internal/provider/privateinternetaccess/provider.go index 856dffe9..096ae9d6 100644 --- a/internal/provider/privateinternetaccess/provider.go +++ b/internal/provider/privateinternetaccess/provider.go @@ -4,6 +4,7 @@ import ( "math/rand" "time" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" ) @@ -12,12 +13,18 @@ type PIA struct { randSource rand.Source timeNow func() time.Time activeServer models.PIAServer + // Port forwarding + portForwardPath string + authFilePath string } -func New(servers []models.PIAServer, randSource rand.Source, timeNow func() time.Time) *PIA { +func New(servers []models.PIAServer, randSource rand.Source, + timeNow func() time.Time) *PIA { return &PIA{ - servers: servers, - timeNow: timeNow, - randSource: randSource, + servers: servers, + timeNow: timeNow, + randSource: randSource, + portForwardPath: constants.PIAPortForward, + authFilePath: constants.OpenVPNAuthConf, } } diff --git a/internal/provider/privatevpn/portforward.go b/internal/provider/privatevpn/portforward.go index 691ad37b..ec777aaa 100644 --- a/internal/provider/privatevpn/portforward.go +++ b/internal/provider/privatevpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (p *Privatevpn) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for PrivateVPN") } diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go index fd86af10..c7990fb5 100644 --- a/internal/provider/protonvpn/portforward.go +++ b/internal/provider/protonvpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (p *Protonvpn) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for ProtonVPN") } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index e3c99c5d..2e7ce5f1 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -30,7 +30,6 @@ import ( "github.com/qdm12/gluetun/internal/provider/vyprvpn" "github.com/qdm12/gluetun/internal/provider/windscribe" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) // Provider contains methods to read and modify the openvpn configuration to connect as a client. @@ -38,7 +37,7 @@ type Provider interface { GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error) BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) } diff --git a/internal/provider/purevpn/portforward.go b/internal/provider/purevpn/portforward.go index 5312681c..d6f19664 100644 --- a/internal/provider/purevpn/portforward.go +++ b/internal/provider/purevpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (p *Purevpn) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for PureVPN") } diff --git a/internal/provider/surfshark/portforward.go b/internal/provider/surfshark/portforward.go index 5822a50b..e9220276 100644 --- a/internal/provider/surfshark/portforward.go +++ b/internal/provider/surfshark/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (s *Surfshark) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Surfshark") } diff --git a/internal/provider/torguard/portforward.go b/internal/provider/torguard/portforward.go index f7835a15..1688777c 100644 --- a/internal/provider/torguard/portforward.go +++ b/internal/provider/torguard/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (t *Torguard) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Torguard") } diff --git a/internal/provider/vpnunlimited/portforward.go b/internal/provider/vpnunlimited/portforward.go index db5ee010..c5191a98 100644 --- a/internal/provider/vpnunlimited/portforward.go +++ b/internal/provider/vpnunlimited/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) func (p *Provider) PortForward(ctx context.Context, client *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for VPN Unlimited") } diff --git a/internal/provider/vyprvpn/portforward.go b/internal/provider/vyprvpn/portforward.go index 9ede1359..b50d1858 100644 --- a/internal/provider/vyprvpn/portforward.go +++ b/internal/provider/vyprvpn/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) -func (v *Vyprvpn) PortForward(ctx context.Context, clienv *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { +func (v *Vyprvpn) PortForward(ctx context.Context, client *http.Client, + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Vyprvpn") } diff --git a/internal/provider/windscribe/portforward.go b/internal/provider/windscribe/portforward.go index eeccccb8..253f1e43 100644 --- a/internal/provider/windscribe/portforward.go +++ b/internal/provider/windscribe/portforward.go @@ -7,11 +7,10 @@ import ( "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) -func (w *Windscribe) PortForward(ctx context.Context, clienw *http.Client, - openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, - fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { +func (w *Windscribe) PortForward(ctx context.Context, client *http.Client, + pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + syncState func(port uint16) (pfFilepath string)) { panic("port forwarding is not supported for Windscribe") } diff --git a/internal/publicip/fs.go b/internal/publicip/fs.go index 92ad5d71..750d97d0 100644 --- a/internal/publicip/fs.go +++ b/internal/publicip/fs.go @@ -1,13 +1,11 @@ package publicip -import "github.com/qdm12/golibs/os" +import ( + "os" +) -func persistPublicIP(openFile os.OpenFileFunc, - filepath string, content string, puid, pgid int) error { - file, err := openFile( - filepath, - os.O_TRUNC|os.O_WRONLY|os.O_CREATE, - 0644) +func persistPublicIP(path string, content string, puid, pgid int) error { + file, err := os.OpenFile(path, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return err } diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index 6b686b7f..e5684176 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "os" "sync" "time" @@ -11,7 +12,6 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) type Looper interface { @@ -31,7 +31,6 @@ type looper struct { getter IPGetter client *http.Client logger logging.Logger - os os.OS // Fixed settings puid int pgid int @@ -51,8 +50,7 @@ type looper struct { const defaultBackoffTime = 5 * time.Second func NewLooper(client *http.Client, logger logging.Logger, - settings configuration.PublicIP, puid, pgid int, - os os.OS) Looper { + settings configuration.PublicIP, puid, pgid int) Looper { return &looper{ state: state{ status: constants.Stopped, @@ -62,7 +60,6 @@ func NewLooper(client *http.Client, logger logging.Logger, client: client, getter: NewIPGetter(client), logger: logger, - os: os, puid: puid, pgid: pgid, start: make(chan struct{}), @@ -136,7 +133,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { close(errorCh) filepath := l.GetSettings().IPFilepath l.logger.Info("Removing ip file " + filepath) - if err := l.os.Remove(filepath); err != nil { + if err := os.Remove(filepath); err != nil { l.logger.Error(err) } return @@ -161,7 +158,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { } l.logger.Info(message) - err = persistPublicIP(l.os.OpenFile, l.state.settings.IPFilepath, + err = persistPublicIP(l.state.settings.IPFilepath, ip.String(), l.puid, l.pgid) if err != nil { l.logger.Error(err) diff --git a/internal/storage/flush.go b/internal/storage/flush.go new file mode 100644 index 00000000..4846905f --- /dev/null +++ b/internal/storage/flush.go @@ -0,0 +1,7 @@ +package storage + +import "github.com/qdm12/gluetun/internal/models" + +func (s *storage) FlushToFile(allServers models.AllServers) error { + return flushToFile(s.filepath, allServers) +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index d6c8d6f4..574aa236 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -4,7 +4,6 @@ package storage import ( "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/logging" - "github.com/qdm12/golibs/os" ) type Storage interface { @@ -14,14 +13,12 @@ type Storage interface { } type storage struct { - os os.OS logger logging.Logger filepath string } -func New(logger logging.Logger, os os.OS, filepath string) Storage { +func New(logger logging.Logger, filepath string) Storage { return &storage{ - os: os, logger: logger, filepath: filepath, } diff --git a/internal/storage/sync.go b/internal/storage/sync.go index 0411d569..f72c2d6d 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -5,11 +5,11 @@ import ( "errors" "fmt" "io" + "os" "path/filepath" "reflect" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/os" ) var ( @@ -39,7 +39,7 @@ func countServers(allServers models.AllServers) int { func (s *storage) SyncServers(hardcodedServers models.AllServers) ( allServers models.AllServers, err error) { - serversOnFile, err := s.readFromFile(s.filepath) + serversOnFile, err := readFromFile(s.filepath) if err != nil { return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err) } @@ -62,14 +62,14 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers) ( return allServers, nil } - if err := s.FlushToFile(allServers); err != nil { + if err := flushToFile(s.filepath, allServers); err != nil { return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err) } return allServers, nil } -func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) { - file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0) +func readFromFile(filepath string) (servers models.AllServers, err error) { + file, err := os.Open(filepath) if os.IsNotExist(err) { return servers, nil } else if err != nil { @@ -86,13 +86,13 @@ func (s *storage) readFromFile(filepath string) (servers models.AllServers, err return servers, file.Close() } -func (s *storage) FlushToFile(servers models.AllServers) error { - dirPath := filepath.Dir(s.filepath) - if err := s.os.MkdirAll(dirPath, 0644); err != nil { +func flushToFile(path string, servers models.AllServers) error { + dirPath := filepath.Dir(path) + if err := os.MkdirAll(dirPath, 0644); err != nil { return err } - file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return err }