From 73479bab260afdf27ad2e485ad5e64030633b6a0 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 29 Dec 2020 00:55:31 +0000 Subject: [PATCH] Code maintenance: OS package for file system - OS custom internal package for file system interaction - Remove fileManager external dependency - Closer API to Go's native API on the OS - Create directories at startup - Better testability - Move Unsetenv to os interface --- .golangci.yml | 4 + cmd/gluetun/main.go | 53 ++++++---- internal/alpine/alpine.go | 16 +-- internal/alpine/users.go | 14 +-- internal/cli/cli.go | 29 +++-- internal/constants/permissions.go | 8 -- internal/dns/conf.go | 30 ++++-- internal/dns/dns.go | 25 ++--- internal/dns/nameserver.go | 19 +++- internal/dns/nameserver_test.go | 61 ++++++++--- internal/dns/roots.go | 51 +++++---- internal/dns/roots_test.go | 170 +++++++++++++++++------------- internal/firewall/firewall.go | 8 +- internal/firewall/iptables.go | 16 ++- internal/openvpn/auth.go | 65 +++++++++--- internal/openvpn/loop.go | 31 ++++-- internal/openvpn/openvpn.go | 27 +++-- internal/openvpn/tun.go | 18 +++- internal/os/alias.go | 9 ++ internal/os/constants.go | 16 +++ internal/os/file.go | 15 +++ internal/os/funcs.go | 10 ++ internal/os/mock_os/file.go | 121 +++++++++++++++++++++ internal/os/mock_os/os.go | 121 +++++++++++++++++++++ internal/os/os.go | 39 +++++++ internal/params/cyberghost.go | 24 ++++- internal/params/openvpn.go | 4 +- internal/params/params.go | 23 ++-- internal/params/shadowsocks.go | 2 +- internal/provider/cyberghost.go | 4 +- internal/provider/mullvad.go | 4 +- internal/provider/nordvpn.go | 4 +- internal/provider/piav4.go | 101 +++++++++++------- internal/provider/privado.go | 4 +- internal/provider/provider.go | 4 +- internal/provider/purevpn.go | 4 +- internal/provider/surfshark.go | 4 +- internal/provider/vyprvpn.go | 4 +- internal/provider/windscribe.go | 4 +- internal/publicip/fs.go | 27 +++++ internal/publicip/loop.go | 25 ++--- internal/storage/storage.go | 18 ++-- internal/storage/sync.go | 40 ++++--- 43 files changed, 923 insertions(+), 353 deletions(-) delete mode 100644 internal/constants/permissions.go create mode 100644 internal/os/alias.go create mode 100644 internal/os/constants.go create mode 100644 internal/os/file.go create mode 100644 internal/os/funcs.go create mode 100644 internal/os/mock_os/file.go create mode 100644 internal/os/mock_os/os.go create mode 100644 internal/os/os.go create mode 100644 internal/publicip/fs.go diff --git a/.golangci.yml b/.golangci.yml index 5490793e..1c3bbe65 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,6 +10,10 @@ issues: linters: - dupl - maligned + - path: internal/os/alias\.go + linters: + - gochecknoglobals + text: IsNotExist is a global variable linters: disable-all: true diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index dc5a4f5a..8232186d 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -5,7 +5,7 @@ import ( "fmt" "net" "net/http" - "os" + nativeos "os" "os/signal" "strings" "sync" @@ -22,6 +22,7 @@ import ( gluetunLogging "github.com/qdm12/gluetun/internal/logging" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/routing" @@ -32,7 +33,6 @@ import ( "github.com/qdm12/gluetun/internal/updater" versionpkg "github.com/qdm12/gluetun/internal/version" "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" ) @@ -50,21 +50,24 @@ func main() { buildInfo.Commit = commit buildInfo.BuildDate = buildDate ctx := context.Background() - os.Exit(_main(ctx, os.Args)) + args := nativeos.Args + os := os.New() + nativeos.Exit(_main(ctx, args, os)) } -func _main(background context.Context, args []string) int { //nolint:gocognit,gocyclo +//nolint:gocognit,gocyclo +func _main(background context.Context, args []string, os os.OS) int { if len(args) > 1 { // cli operation var err error switch args[1] { case "healthcheck": err = cli.HealthCheck(background) case "clientkey": - err = cli.ClientKey(args[2:]) + err = cli.ClientKey(args[2:], os.OpenFile) case "openvpnconfig": - err = cli.OpenvpnConfig() + err = cli.OpenvpnConfig(os) case "update": - err = cli.Update(args[2:]) + err = cli.Update(args[2:], os) default: err = fmt.Errorf("command %q is unknown", args[1]) } @@ -82,15 +85,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go httpClient := &http.Client{Timeout: clientTimeout} client := network.NewClient(clientTimeout) // Create configurators - fileManager := files.NewFileManager() - alpineConf := alpine.NewConfigurator(fileManager) - ovpnConf := openvpn.NewConfigurator(logger, fileManager) - dnsConf := dns.NewConfigurator(logger, client, fileManager) + alpineConf := alpine.NewConfigurator(os.OpenFile) + ovpnConf := openvpn.NewConfigurator(logger, os) + dnsConf := dns.NewConfigurator(logger, client, os.OpenFile) routingConf := routing.NewRouting(logger) - firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager) + firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile) streamMerger := command.NewStreamMerger() - paramsReader := params.NewReader(logger, fileManager) + paramsReader := params.NewReader(logger, os) fmt.Println(gluetunLogging.Splash(buildInfo)) printVersions(ctx, logger, map[string]func(ctx context.Context) (string, error){ @@ -106,8 +108,17 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go } logger.Info(allSettings.String()) + if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil { + logger.Error(err) + return 1 + } + if err := os.MkdirAll("/gluetun", 0644); err != nil { + logger.Error(err) + return 1 + } + // TODO run this in a loop or in openvpn to reload from file without restarting - storage := storage.New(logger) + storage := storage.New(logger, os) const updateServerFile = true allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile) if err != nil { @@ -124,8 +135,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go logger.Error(err) return 1 } - err = fileManager.SetOwnership("/etc/unbound", uid, gid) - if err != nil { + + if err := os.Chown("/etc/unbound", uid, gid); err != nil { logger.Error(err) return 1 } @@ -219,7 +230,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, uid, gid, allServers, - ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel) + ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel) wg.Add(1) // wait for restartOpenvpn go openvpnLooper.Run(ctx, wg) @@ -236,7 +247,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go go unboundLooper.Run(ctx, wg, signalDNSReady) publicIPLooper := publicip.NewLooper( - client, logger, fileManager, allSettings.PublicIP, uid, gid) + client, logger, allSettings.PublicIP, uid, gid, os) wg.Add(1) go publicIPLooper.Run(ctx, wg) wg.Add(1) @@ -279,11 +290,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go // until openvpn is launched _, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable - signalsCh := make(chan os.Signal, 1) + signalsCh := make(chan nativeos.Signal, 1) signal.Notify(signalsCh, syscall.SIGINT, syscall.SIGTERM, - os.Interrupt, + nativeos.Interrupt, ) shutdownErrorsCount := 0 select { @@ -295,7 +306,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go } if allSettings.OpenVPN.Provider.PortForwarding.Enabled { logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) - if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { + if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { logger.Error(err) shutdownErrorsCount++ } diff --git a/internal/alpine/alpine.go b/internal/alpine/alpine.go index 3c3e3fbe..d3b2a7ae 100644 --- a/internal/alpine/alpine.go +++ b/internal/alpine/alpine.go @@ -3,7 +3,7 @@ package alpine import ( "os/user" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" ) type Configurator interface { @@ -11,15 +11,15 @@ type Configurator interface { } type configurator struct { - fileManager files.FileManager - lookupUID func(uid string) (*user.User, error) - lookupUser func(username string) (*user.User, error) + openFile os.OpenFileFunc + lookupUID func(uid string) (*user.User, error) + lookupUser func(username string) (*user.User, error) } -func NewConfigurator(fileManager files.FileManager) Configurator { +func NewConfigurator(openFile os.OpenFileFunc) Configurator { return &configurator{ - fileManager: fileManager, - lookupUID: user.LookupId, - lookupUser: user.Lookup, + openFile: openFile, + lookupUID: user.LookupId, + lookupUser: user.Lookup, } } diff --git a/internal/alpine/users.go b/internal/alpine/users.go index 9ec3ed2d..d997bc7f 100644 --- a/internal/alpine/users.go +++ b/internal/alpine/users.go @@ -2,6 +2,7 @@ package alpine import ( "fmt" + "os" "os/user" ) @@ -26,14 +27,15 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str return "", fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d", username, u.Uid, uid) } - passwd, err := c.fileManager.ReadFile("/etc/passwd") + file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644) if err != nil { return "", fmt.Errorf("cannot create user: %w", err) } - passwd = append(passwd, []byte(fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid))...) - - if err := c.fileManager.WriteToFile("/etc/passwd", passwd); err != nil { - return "", fmt.Errorf("cannot create user: %w", err) + s := fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid) + _, err = file.WriteString(s) + if err != nil { + _ = file.Close() + return "", err } - return username, nil + return username, file.Close() } diff --git a/internal/cli/cli.go b/internal/cli/cli.go index c50f3b02..f31d3ce2 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -4,29 +4,40 @@ import ( "context" "flag" "fmt" + "io/ioutil" "net/http" "strings" "time" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/healthcheck" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/updater" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" ) -func ClientKey(args []string) error { +func ClientKey(args []string, openFile os.OpenFileFunc) error { flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError) filepath := flagSet.String("path", string(constants.ClientKey), "file path to the client.key file") if err := flagSet.Parse(args); err != nil { return err } - fileManager := files.NewFileManager() - data, err := fileManager.ReadFile(*filepath) + file, err := openFile(*filepath, os.O_RDONLY, 0) + if err != nil { + return err + } + data, err := ioutil.ReadAll(file) + if err != nil { + _ = file.Close() + return err + } + if err := file.Close(); err != nil { + return err + } if err != nil { return err } @@ -49,17 +60,17 @@ func HealthCheck(ctx context.Context) error { return healthchecker.Check(ctx, url) } -func OpenvpnConfig() error { +func OpenvpnConfig(os os.OS) error { logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel) if err != nil { return err } - paramsReader := params.NewReader(logger, files.NewFileManager()) + paramsReader := params.NewReader(logger, os) allSettings, err := settings.GetAllSettings(paramsReader) if err != nil { return err } - allServers, err := storage.New(logger).SyncServers(constants.GetAllServers(), false) + allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false) if err != nil { return err } @@ -81,7 +92,7 @@ func OpenvpnConfig() error { return nil } -func Update(args []string) error { +func Update(args []string, os os.OS) error { options := settings.Updater{CLI: true} var flushToFile bool flagSet := flag.NewFlagSet("update", flag.ExitOnError) @@ -110,7 +121,7 @@ func Update(args []string) error { ctx := context.Background() const clientTimeout = 10 * time.Second httpClient := &http.Client{Timeout: clientTimeout} - storage := storage.New(logger) + storage := storage.New(logger, os) const writeSync = false currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync) if err != nil { diff --git a/internal/constants/permissions.go b/internal/constants/permissions.go deleted file mode 100644 index 1cb738fd..00000000 --- a/internal/constants/permissions.go +++ /dev/null @@ -1,8 +0,0 @@ -package constants - -import "os" - -const ( - UserReadPermission os.FileMode = 0400 - AllReadWritePermissions os.FileMode = 0666 -) diff --git a/internal/dns/conf.go b/internal/dns/conf.go index 8da871f3..3ab1b79a 100644 --- a/internal/dns/conf.go +++ b/internal/dns/conf.go @@ -8,8 +8,8 @@ import ( "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/settings" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" ) @@ -21,11 +21,29 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN for _, warning := range warnings { c.logger.Warn(warning) } - return c.fileManager.WriteLinesToFile( - string(constants.UnboundConf), - lines, - files.Ownership(uid, gid), - files.Permissions(constants.UserReadPermission)) + + const filepath = string(constants.UnboundConf) + file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400) + if err != nil { + return err + } + + _, err = file.WriteString(strings.Join(lines, "\n")) + if err != nil { + _ = file.Close() + return err + } + + if err := file.Chown(uid, gid); err != nil { + _ = file.Close() + return err + } + + if err := file.Close(); err != nil { + return err + } + + return nil } // MakeUnboundConf generates an Unbound configuration from the user provided settings. diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 4e3da2bc..6a5d522c 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -5,9 +5,9 @@ import ( "io" "net" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" ) @@ -24,19 +24,20 @@ type Configurator interface { } type configurator struct { - logger logging.Logger - client network.Client - fileManager files.FileManager - commander command.Commander - lookupIP func(host string) ([]net.IP, error) + logger logging.Logger + client network.Client + openFile os.OpenFileFunc + commander command.Commander + lookupIP func(host string) ([]net.IP, error) } -func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator { +func NewConfigurator(logger logging.Logger, client network.Client, + openFile os.OpenFileFunc) Configurator { return &configurator{ - logger: logger.WithPrefix("dns configurator: "), - client: client, - fileManager: fileManager, - commander: command.NewCommander(), - lookupIP: net.LookupIP, + logger: logger.WithPrefix("dns configurator: "), + client: client, + openFile: openFile, + commander: command.NewCommander(), + lookupIP: net.LookupIP, } } diff --git a/internal/dns/nameserver.go b/internal/dns/nameserver.go index ab0d4f10..61ce16e9 100644 --- a/internal/dns/nameserver.go +++ b/internal/dns/nameserver.go @@ -2,10 +2,12 @@ package dns import ( "context" + "io/ioutil" "net" "strings" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/os" ) // UseDNSInternally is to change the Go program DNS only. @@ -23,10 +25,16 @@ func (c *configurator) UseDNSInternally(ip net.IP) { // UseDNSSystemWide changes the nameserver to use for DNS system wide. func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { c.logger.Info("using DNS address %s system wide", ip.String()) - data, err := c.fileManager.ReadFile(string(constants.ResolvConf)) + const filepath = string(constants.ResolvConf) + file, err := c.openFile(filepath, os.O_RDWR, 0644) if err != nil { return err } + data, err := ioutil.ReadAll(file) + if err != nil { + _ = file.Close() + return err + } s := strings.TrimSuffix(string(data), "\n") lines := strings.Split(s, "\n") if len(lines) == 1 && lines[0] == "" { @@ -44,6 +52,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { if !found { lines = append(lines, "nameserver "+ip.String()) } - data = []byte(strings.Join(lines, "\n")) - return c.fileManager.WriteToFile(string(constants.ResolvConf), data) + s = strings.Join(lines, "\n") + _, err = file.WriteString(s) + if err != nil { + _ = file.Close() + return err + } + return file.Close() } diff --git a/internal/dns/nameserver_test.go b/internal/dns/nameserver_test.go index ace1b333..1c88aa80 100644 --- a/internal/dns/nameserver_test.go +++ b/internal/dns/nameserver_test.go @@ -2,12 +2,14 @@ package dns import ( "fmt" + "io" "net" "testing" "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/files/mock_files" + "github.com/qdm12/gluetun/internal/os" + "github.com/qdm12/gluetun/internal/os/mock_os" "github.com/qdm12/golibs/logging/mock_logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,30 +19,36 @@ func Test_UseDNSSystemWide(t *testing.T) { t.Parallel() tests := map[string]struct { data []byte - writtenData []byte + writtenData string + openErr error readErr error writeErr error + closeErr error err error }{ "no data": { - writtenData: []byte("nameserver 127.0.0.1"), + writtenData: "nameserver 127.0.0.1", + }, + "open error": { + openErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), }, "read error": { readErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, "write error": { - writtenData: []byte("nameserver 127.0.0.1"), + writtenData: "nameserver 127.0.0.1", writeErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, "lines without nameserver": { data: []byte("abc\ndef\n"), - writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"), + writtenData: "abc\ndef\nnameserver 127.0.0.1", }, "lines with nameserver": { data: []byte("abc\nnameserver abc def\ndef\n"), - writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"), + writtenData: "abc\nnameserver 127.0.0.1\ndef", }, } for name, tc := range tests { @@ -49,18 +57,43 @@ func Test_UseDNSSystemWide(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - fileManager := mock_files.NewMockFileManager(mockCtrl) - fileManager.EXPECT().ReadFile(string(constants.ResolvConf)). - Return(tc.data, tc.readErr) - if tc.readErr == nil { - fileManager.EXPECT().WriteToFile(string(constants.ResolvConf), tc.writtenData). - Return(tc.writeErr) + + file := mock_os.NewMockFile(mockCtrl) + if tc.openErr == nil { + firstReadCall := file.EXPECT(). + Read(gomock.AssignableToTypeOf([]byte{})). + DoAndReturn(func(b []byte) (int, error) { + copy(b, tc.data) + return len(tc.data), nil + }) + readErr := tc.readErr + if readErr == nil { + readErr = io.EOF + } + finalReadCall := file.EXPECT(). + Read(gomock.AssignableToTypeOf([]byte{})). + Return(0, readErr).After(firstReadCall) + if tc.readErr == nil { + writeCall := file.EXPECT().WriteString(tc.writtenData). + Return(0, tc.writeErr).After(finalReadCall) + file.EXPECT().Close().Return(tc.closeErr).After(writeCall) + } else { + file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall) + } } + + openFile := func(name string, flag int, perm os.FileMode) (os.File, error) { + assert.Equal(t, string(constants.ResolvConf), name) + assert.Equal(t, os.O_RDWR, flag) + assert.Equal(t, os.FileMode(0644), perm) + return file, tc.openErr + } + logger := mock_logging.NewMockLogger(mockCtrl) logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1") c := &configurator{ - fileManager: fileManager, - logger: logger, + openFile: openFile, + logger: logger, } err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false) if tc.err != nil { diff --git a/internal/dns/roots.go b/internal/dns/roots.go index 3fc190f1..103c81da 100644 --- a/internal/dns/roots.go +++ b/internal/dns/roots.go @@ -4,37 +4,46 @@ import ( "context" "fmt" "net/http" + "os" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/files" ) func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error { - c.logger.Info("downloading root hints from %s", constants.NamedRootURL) - content, status, err := c.client.Get(ctx, string(constants.NamedRootURL)) - if err != nil { - return err - } else if status != http.StatusOK { - return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL) - } - return c.fileManager.WriteToFile( - string(constants.RootHints), - content, - files.Ownership(uid, gid), - files.Permissions(constants.UserReadPermission)) + return c.downloadAndSave(ctx, "root hints", + string(constants.NamedRootURL), string(constants.RootHints), uid, gid) } func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error { - c.logger.Info("downloading root key from %s", constants.RootKeyURL) - content, status, err := c.client.Get(ctx, string(constants.RootKeyURL)) + return c.downloadAndSave(ctx, "root key", + string(constants.RootKeyURL), string(constants.RootKey), uid, gid) +} + +func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error { + c.logger.Info("downloading %s from %s", logName, url) + content, status, err := c.client.Get(ctx, url) if err != nil { return err } else if status != http.StatusOK { - return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL) + return fmt.Errorf("HTTP status code is %d for %s", status, url) } - return c.fileManager.WriteToFile( - string(constants.RootKey), - content, - files.Ownership(uid, gid), - files.Permissions(constants.UserReadPermission)) + + file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400) + if err != nil { + return err + } + + _, err = file.Write(content) + if err != nil { + _ = file.Close() + return err + } + + err = file.Chown(uid, gid) + if err != nil { + _ = file.Close() + return err + } + + return file.Close() } diff --git a/internal/dns/roots_test.go b/internal/dns/roots_test.go index e8a4237b..d2f5be01 100644 --- a/internal/dns/roots_test.go +++ b/internal/dns/roots_test.go @@ -2,27 +2,31 @@ package dns import ( "context" + "errors" "fmt" "net/http" "testing" "github.com/golang/mock/gomock" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/files" - "github.com/qdm12/golibs/files/mock_files" + "github.com/qdm12/gluetun/internal/os" + "github.com/qdm12/gluetun/internal/os/mock_os" "github.com/qdm12/golibs/logging/mock_logging" "github.com/qdm12/golibs/network/mock_network" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_DownloadRootHints(t *testing.T) { //nolint:dupl +func Test_downloadAndSave(t *testing.T) { t.Parallel() tests := map[string]struct { content []byte status int clientErr error + openErr error writeErr error + chownErr error + closeErr error err error }{ "no data": { @@ -36,11 +40,26 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl clientErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, + "open error": { + status: http.StatusOK, + openErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), + }, "write error": { status: http.StatusOK, writeErr: fmt.Errorf("error"), err: fmt.Errorf("error"), }, + "chown error": { + status: http.StatusOK, + chownErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), + }, + "close error": { + status: http.StatusOK, + closeErr: fmt.Errorf("error"), + err: fmt.Errorf("error"), + }, "data": { content: []byte("content"), status: http.StatusOK, @@ -52,23 +71,49 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl t.Parallel() mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() + ctx := context.Background() logger := mock_logging.NewMockLogger(mockCtrl) - logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL) + logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL)) client := mock_network.NewMockClient(mockCtrl) client.EXPECT().Get(ctx, string(constants.NamedRootURL)). Return(tc.content, tc.status, tc.clientErr) - fileManager := mock_files.NewMockFileManager(mockCtrl) - if tc.clientErr == nil && tc.status == http.StatusOK { - fileManager.EXPECT().WriteToFile( - string(constants.RootHints), - tc.content, - gomock.AssignableToTypeOf(files.Ownership(0, 0)), - gomock.AssignableToTypeOf(files.Ownership(0, 0))). - Return(tc.writeErr) + + openFile := func(name string, flag int, perm os.FileMode) (os.File, error) { + return nil, nil } - c := &configurator{logger: logger, client: client, fileManager: fileManager} - err := c.DownloadRootHints(ctx, 1000, 1000) + + if tc.clientErr == nil && tc.status == http.StatusOK { + file := mock_os.NewMockFile(mockCtrl) + if tc.openErr == nil { + writeCall := file.EXPECT().Write(tc.content). + Return(0, tc.writeErr) + if tc.writeErr != nil { + file.EXPECT().Close().Return(tc.closeErr).After(writeCall) + } else { + chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall) + file.EXPECT().Close().Return(tc.closeErr).After(chownCall) + } + } + + openFile = func(name string, flag int, perm os.FileMode) (os.File, error) { + assert.Equal(t, string(constants.RootHints), name) + assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag) + assert.Equal(t, os.FileMode(0400), perm) + return file, tc.openErr + } + } + + c := &configurator{ + logger: logger, + client: client, + openFile: openFile, + } + + err := c.downloadAndSave(ctx, "root hints", + string(constants.NamedRootURL), string(constants.RootHints), + 1000, 1000) + if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) @@ -79,65 +124,44 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl } } -func Test_DownloadRootKey(t *testing.T) { //nolint:dupl +func Test_DownloadRootHints(t *testing.T) { t.Parallel() - tests := map[string]struct { - content []byte - status int - clientErr error - writeErr error - err error - }{ - "no data": { - status: http.StatusOK, - }, - "bad status": { - status: http.StatusBadRequest, - err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"), //nolint:lll - }, - "client error": { - clientErr: fmt.Errorf("error"), - err: fmt.Errorf("error"), - }, - "write error": { - status: http.StatusOK, - writeErr: fmt.Errorf("error"), - err: fmt.Errorf("error"), - }, - "data": { - content: []byte("content"), - status: http.StatusOK, - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - ctx := context.Background() - logger := mock_logging.NewMockLogger(mockCtrl) - logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL) - client := mock_network.NewMockClient(mockCtrl) - client.EXPECT().Get(ctx, string(constants.RootKeyURL)). - Return(tc.content, tc.status, tc.clientErr) - fileManager := mock_files.NewMockFileManager(mockCtrl) - if tc.clientErr == nil && tc.status == http.StatusOK { - fileManager.EXPECT().WriteToFile( - string(constants.RootKey), - tc.content, - gomock.AssignableToTypeOf(files.Ownership(0, 0)), - gomock.AssignableToTypeOf(files.Ownership(0, 0)), - ).Return(tc.writeErr) - } - c := &configurator{logger: logger, client: client, fileManager: fileManager} - err := c.DownloadRootKey(ctx, 1000, 1001) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - }) + mockCtrl := gomock.NewController(t) + + ctx := context.Background() + logger := mock_logging.NewMockLogger(mockCtrl) + logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL)) + client := mock_network.NewMockClient(mockCtrl) + client.EXPECT().Get(ctx, string(constants.NamedRootURL)). + Return(nil, http.StatusOK, errors.New("test")) + + c := &configurator{ + logger: logger, + client: client, } + + err := c.DownloadRootHints(ctx, 1000, 1000) + require.Error(t, err) + assert.Equal(t, "test", err.Error()) +} + +func Test_DownloadRootKey(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + + ctx := context.Background() + logger := mock_logging.NewMockLogger(mockCtrl) + logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL)) + client := mock_network.NewMockClient(mockCtrl) + client.EXPECT().Get(ctx, string(constants.RootKeyURL)). + Return(nil, http.StatusOK, errors.New("test")) + + c := &configurator{ + logger: logger, + client: client, + } + + err := c.DownloadRootKey(ctx, 1000, 1000) + require.Error(t, err) + assert.Equal(t, "test", err.Error()) } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index d126c43a..9e46b38c 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -6,9 +6,9 @@ import ( "sync" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" ) @@ -29,7 +29,7 @@ type configurator struct { //nolint:maligned commander command.Commander logger logging.Logger routing routing.Routing - fileManager files.FileManager // for custom iptables rules + openFile os.OpenFileFunc // for custom iptables rules iptablesMutex sync.Mutex debug bool defaultInterface string @@ -47,12 +47,12 @@ type configurator struct { //nolint:maligned } // NewConfigurator creates a new Configurator instance. -func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator { +func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator { return &configurator{ commander: command.NewCommander(), logger: logger.WithPrefix("firewall: "), routing: routing, - fileManager: fileManager, + openFile: openFile, allowedInputPorts: make(map[uint16]string), } } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 5e380b28..1d224b45 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -3,7 +3,9 @@ package firewall import ( "context" "fmt" + "io/ioutil" "net" + "os" "strings" "github.com/qdm12/gluetun/internal/models" @@ -150,14 +152,18 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port } func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { - exists, err := c.fileManager.FileExists(filepath) - if err != nil { - return err - } else if !exists { + file, err := c.openFile(filepath, os.O_RDONLY, 0) + if os.IsNotExist(err) { return nil + } else if err != nil { + return err } - b, err := c.fileManager.ReadFile(filepath) + b, err := ioutil.ReadAll(file) if err != nil { + _ = file.Close() + return err + } + if err := file.Close(); err != nil { return err } lines := strings.Split(string(b), "\n") diff --git a/internal/openvpn/auth.go b/internal/openvpn/auth.go index 231b8967..54ca4461 100644 --- a/internal/openvpn/auth.go +++ b/internal/openvpn/auth.go @@ -1,31 +1,68 @@ package openvpn import ( + "io/ioutil" + "os" "strings" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/files" ) // WriteAuthFile writes the OpenVPN auth file to disk with the right permissions. func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error { - exists, err := c.fileManager.FileExists(string(constants.OpenVPNAuthConf)) - if err != nil { + const filepath = string(constants.OpenVPNAuthConf) + file, err := c.os.OpenFile(filepath, os.O_RDONLY, 0) + + if err != nil && !os.IsNotExist(err) { return err - } else if exists { - data, err := c.fileManager.ReadFile(string(constants.OpenVPNAuthConf)) + } + + if os.IsNotExist(err) { + file, err = c.os.OpenFile(filepath, os.O_WRONLY|os.O_CREATE, 0400) if err != nil { return err } - lines := strings.Split(string(data), "\n") - if len(lines) > 1 && lines[0] == user && lines[1] == password { - return nil + _, err = file.WriteString(user + "\n" + password) + if err != nil { + _ = file.Close() + return err } - c.logger.Info("username and password changed", constants.OpenVPNAuthConf) + err = file.Chown(uid, gid) + if err != nil { + _ = file.Close() + return err + } + return file.Close() } - return c.fileManager.WriteLinesToFile( - string(constants.OpenVPNAuthConf), - []string{user, password}, - files.Ownership(uid, gid), - files.Permissions(constants.UserReadPermission)) + + data, err := ioutil.ReadAll(file) + if err != nil { + _ = file.Close() + return err + } + if err := file.Close(); err != nil { + return err + } + + lines := strings.Split(string(data), "\n") + if len(lines) > 1 && lines[0] == user && lines[1] == password { + return nil + } + + c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf) + file, err = c.os.OpenFile(filepath, os.O_TRUNC|os.O_WRONLY, 0400) + if err != nil { + return err + } + _, err = file.WriteString(user + "\n" + password) + if err != nil { + _ = file.Close() + return err + } + err = file.Chown(uid, gid) + if err != nil { + _ = file.Close() + return err + } + return file.Close() } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index ee415ffd..86ae927f 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -4,17 +4,18 @@ import ( "context" "net" "net/http" + "strings" "sync" "time" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" ) @@ -43,7 +44,7 @@ type looper struct { // Other objects logger, pfLogger logging.Logger client *http.Client - fileManager files.FileManager + openFile os.OpenFileFunc streamMerger command.StreamMerger cancel context.CancelFunc // Internal channels and locks @@ -57,7 +58,7 @@ type looper struct { func NewLooper(settings settings.OpenVPN, username string, uid, gid int, allServers models.AllServers, conf Configurator, fw firewall.Configurator, routing routing.Routing, - logger logging.Logger, client *http.Client, fileManager files.FileManager, + logger logging.Logger, client *http.Client, openFile os.OpenFileFunc, streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { return &looper{ state: state{ @@ -74,7 +75,7 @@ func NewLooper(settings settings.OpenVPN, logger: logger.WithPrefix("openvpn: "), pfLogger: logger.WithPrefix("port forwarding: "), client: client, - fileManager: fileManager, + openFile: openFile, streamMerger: streamMerger, cancel: cancel, start: make(chan struct{}), @@ -115,8 +116,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { settings.Auth, settings.Provider.ExtraConfigOptions, ) - if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines, - files.Ownership(l.uid, l.gid), files.Permissions(constants.UserReadPermission)); err != nil { + + if err := writeOpenvpnConf(lines, l.openFile); err != nil { l.logger.Error(err) l.cancel() return @@ -239,6 +240,22 @@ func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup, return settings.Provider.PortForwarding.Filepath } providerConf.PortForward(ctx, - client, l.fileManager, l.pfLogger, + client, l.openFile, l.pfLogger, gateway, l.fw, syncState) } + +func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error { + const filepath = string(constants.OpenVPNConf) + file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return err + } + _, err = file.WriteString(strings.Join(lines, "\n")) + if err != nil { + return err + } + if err := file.Close(); err != nil { + return err + } + return nil +} diff --git a/internal/openvpn/openvpn.go b/internal/openvpn/openvpn.go index a25fc042..d7e1e68e 100644 --- a/internal/openvpn/openvpn.go +++ b/internal/openvpn/openvpn.go @@ -3,10 +3,9 @@ package openvpn import ( "context" "io" - "os" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "golang.org/x/sys/unix" ) @@ -20,21 +19,19 @@ type Configurator interface { } type configurator struct { - fileManager files.FileManager - logger logging.Logger - commander command.Commander - openFile func(name string, flag int, perm os.FileMode) (*os.File, error) - mkDev func(major uint32, minor uint32) uint64 - mkNod func(path string, mode uint32, dev int) error + logger logging.Logger + commander command.Commander + os os.OS + mkDev func(major uint32, minor uint32) uint64 + mkNod func(path string, mode uint32, dev int) error } -func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator { +func NewConfigurator(logger logging.Logger, os os.OS) Configurator { return &configurator{ - fileManager: fileManager, - logger: logger.WithPrefix("openvpn configurator: "), - commander: command.NewCommander(), - openFile: os.OpenFile, - mkDev: unix.Mkdev, - mkNod: unix.Mknod, + logger: logger.WithPrefix("openvpn configurator: "), + commander: command.NewCommander(), + os: os, + mkDev: unix.Mkdev, + mkNod: unix.Mknod, } } diff --git a/internal/openvpn/tun.go b/internal/openvpn/tun.go index bc760bbf..fbb3c8eb 100644 --- a/internal/openvpn/tun.go +++ b/internal/openvpn/tun.go @@ -11,7 +11,7 @@ import ( // CheckTUN checks the tunnel device is present and accessible. func (c *configurator) CheckTUN() error { c.logger.Info("checking for device %s", constants.TunnelDevice) - f, err := c.openFile(string(constants.TunnelDevice), os.O_RDWR, 0) + f, err := c.os.OpenFile(string(constants.TunnelDevice), os.O_RDWR, 0) if err != nil { return fmt.Errorf("TUN device is not available: %w", err) } @@ -23,9 +23,10 @@ func (c *configurator) CheckTUN() error { func (c *configurator) CreateTUN() error { c.logger.Info("creating %s", constants.TunnelDevice) - if err := c.fileManager.CreateDir("/dev/net"); err != nil { + if err := c.os.MkdirAll("/dev/net", 0751); err != nil { return err } + const ( major = 10 minor = 200 @@ -34,8 +35,17 @@ func (c *configurator) CreateTUN() error { if err := c.mkNod(string(constants.TunnelDevice), unix.S_IFCHR, int(dev)); err != nil { return err } - if err := c.fileManager.SetUserPermissions(string(constants.TunnelDevice), 0666); err != nil { + + const filepath = string(constants.TunnelDevice) + file, err := c.os.OpenFile(filepath, os.O_WRONLY, 0666) + if err != nil { return err } - return nil + const readWriteAllPerms os.FileMode = 0666 + if err := file.Chmod(readWriteAllPerms); err != nil { + _ = file.Close() + return err + } + + return file.Close() } diff --git a/internal/os/alias.go b/internal/os/alias.go new file mode 100644 index 00000000..24062954 --- /dev/null +++ b/internal/os/alias.go @@ -0,0 +1,9 @@ +package os + +import nativeos "os" + +// Aliases used for convenience so "os" does not have to be imported + +type FileMode nativeos.FileMode + +var IsNotExist = nativeos.IsNotExist diff --git a/internal/os/constants.go b/internal/os/constants.go new file mode 100644 index 00000000..4ec31b81 --- /dev/null +++ b/internal/os/constants.go @@ -0,0 +1,16 @@ +package os + +import ( + nativeos "os" +) + +// Constants used for convenience so "os" does not have to be imported + +//nolint:golint +const ( + O_CREATE = nativeos.O_CREATE + O_TRUNC = nativeos.O_TRUNC + O_WRONLY = nativeos.O_WRONLY + O_RDONLY = nativeos.O_RDONLY + O_RDWR = nativeos.O_RDWR +) diff --git a/internal/os/file.go b/internal/os/file.go new file mode 100644 index 00000000..41c9e2d6 --- /dev/null +++ b/internal/os/file.go @@ -0,0 +1,15 @@ +package os + +import ( + "io" + nativeos "os" +) + +//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . File + +type File interface { + io.ReadWriteCloser + WriteString(s string) (int, error) + Chown(uid, gid int) error + Chmod(mode nativeos.FileMode) error +} diff --git a/internal/os/funcs.go b/internal/os/funcs.go new file mode 100644 index 00000000..01516ec9 --- /dev/null +++ b/internal/os/funcs.go @@ -0,0 +1,10 @@ +package os + +import ( + nativeos "os" +) + +type OpenFileFunc func(name string, flag int, perm FileMode) (File, error) +type MkdirAllFunc func(name string, perm nativeos.FileMode) error +type RemoveFunc func(name string) error +type ChownFunc func(name string, uid int, gid int) error diff --git a/internal/os/mock_os/file.go b/internal/os/mock_os/file.go new file mode 100644 index 00000000..382d9276 --- /dev/null +++ b/internal/os/mock_os/file.go @@ -0,0 +1,121 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/os (interfaces: File) + +// Package mock_os is a generated GoMock package. +package mock_os + +import ( + gomock "github.com/golang/mock/gomock" + os "os" + reflect "reflect" +) + +// MockFile is a mock of File interface +type MockFile struct { + ctrl *gomock.Controller + recorder *MockFileMockRecorder +} + +// MockFileMockRecorder is the mock recorder for MockFile +type MockFileMockRecorder struct { + mock *MockFile +} + +// NewMockFile creates a new mock instance +func NewMockFile(ctrl *gomock.Controller) *MockFile { + mock := &MockFile{ctrl: ctrl} + mock.recorder = &MockFileMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockFile) EXPECT() *MockFileMockRecorder { + return m.recorder +} + +// Chmod mocks base method +func (m *MockFile) Chmod(arg0 os.FileMode) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Chmod", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Chmod indicates an expected call of Chmod +func (mr *MockFileMockRecorder) Chmod(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chmod", reflect.TypeOf((*MockFile)(nil).Chmod), arg0) +} + +// Chown mocks base method +func (m *MockFile) Chown(arg0, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Chown", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Chown indicates an expected call of Chown +func (mr *MockFileMockRecorder) Chown(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockFile)(nil).Chown), arg0, arg1) +} + +// Close mocks base method +func (m *MockFile) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockFileMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFile)(nil).Close)) +} + +// Read mocks base method +func (m *MockFile) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockFileMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockFile)(nil).Read), arg0) +} + +// Write mocks base method +func (m *MockFile) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *MockFileMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockFile)(nil).Write), arg0) +} + +// WriteString mocks base method +func (m *MockFile) WriteString(arg0 string) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteString", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WriteString indicates an expected call of WriteString +func (mr *MockFileMockRecorder) WriteString(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockFile)(nil).WriteString), arg0) +} diff --git a/internal/os/mock_os/os.go b/internal/os/mock_os/os.go new file mode 100644 index 00000000..1e149bf9 --- /dev/null +++ b/internal/os/mock_os/os.go @@ -0,0 +1,121 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/os (interfaces: OS) + +// Package mock_os is a generated GoMock package. +package mock_os + +import ( + gomock "github.com/golang/mock/gomock" + os "github.com/qdm12/gluetun/internal/os" + os0 "os" + reflect "reflect" +) + +// MockOS is a mock of OS interface +type MockOS struct { + ctrl *gomock.Controller + recorder *MockOSMockRecorder +} + +// MockOSMockRecorder is the mock recorder for MockOS +type MockOSMockRecorder struct { + mock *MockOS +} + +// NewMockOS creates a new mock instance +func NewMockOS(ctrl *gomock.Controller) *MockOS { + mock := &MockOS{ctrl: ctrl} + mock.recorder = &MockOSMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockOS) EXPECT() *MockOSMockRecorder { + return m.recorder +} + +// Chown mocks base method +func (m *MockOS) Chown(arg0 string, arg1, arg2 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Chown", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Chown indicates an expected call of Chown +func (mr *MockOSMockRecorder) Chown(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockOS)(nil).Chown), arg0, arg1, arg2) +} + +// MkdirAll mocks base method +func (m *MockOS) MkdirAll(arg0 string, arg1 os.FileMode) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MkdirAll", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// MkdirAll indicates an expected call of MkdirAll +func (mr *MockOSMockRecorder) MkdirAll(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockOS)(nil).MkdirAll), arg0, arg1) +} + +// OpenFile mocks base method +func (m *MockOS) OpenFile(arg0 string, arg1 int, arg2 os.FileMode) (os.File, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenFile", arg0, arg1, arg2) + ret0, _ := ret[0].(os.File) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenFile indicates an expected call of OpenFile +func (mr *MockOSMockRecorder) OpenFile(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenFile", reflect.TypeOf((*MockOS)(nil).OpenFile), arg0, arg1, arg2) +} + +// Remove mocks base method +func (m *MockOS) Remove(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Remove", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Remove indicates an expected call of Remove +func (mr *MockOSMockRecorder) Remove(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockOS)(nil).Remove), arg0) +} + +// Stat mocks base method +func (m *MockOS) Stat(arg0 string) (os0.FileInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stat", arg0) + ret0, _ := ret[0].(os0.FileInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Stat indicates an expected call of Stat +func (mr *MockOSMockRecorder) Stat(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockOS)(nil).Stat), arg0) +} + +// Unsetenv mocks base method +func (m *MockOS) Unsetenv(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unsetenv", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unsetenv indicates an expected call of Unsetenv +func (mr *MockOSMockRecorder) Unsetenv(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsetenv", reflect.TypeOf((*MockOS)(nil).Unsetenv), arg0) +} diff --git a/internal/os/os.go b/internal/os/os.go new file mode 100644 index 00000000..8c686704 --- /dev/null +++ b/internal/os/os.go @@ -0,0 +1,39 @@ +package os + +import nativeos "os" + +//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . OS + +type OS interface { + OpenFile(name string, flag int, perm FileMode) (File, error) + MkdirAll(name string, perm FileMode) error + Remove(name string) error + Chown(name string, uid int, gid int) error + Unsetenv(key string) error + Stat(name string) (nativeos.FileInfo, error) +} + +func New() OS { + return &os{} +} + +type os struct{} + +func (o *os) OpenFile(name string, flag int, perm FileMode) (File, error) { + return nativeos.OpenFile(name, flag, nativeos.FileMode(perm)) +} +func (o *os) MkdirAll(name string, perm FileMode) error { + return nativeos.MkdirAll(name, nativeos.FileMode(perm)) +} +func (o *os) Remove(name string) error { + return nativeos.Remove(name) +} +func (o *os) Chown(name string, uid, gid int) error { + return nativeos.Chown(name, uid, gid) +} +func (o *os) Unsetenv(key string) error { + return nativeos.Unsetenv(key) +} +func (o *os) Stat(name string) (nativeos.FileInfo, error) { + return nativeos.Stat(name) +} diff --git a/internal/params/cyberghost.go b/internal/params/cyberghost.go index 719acd23..b9c6b914 100644 --- a/internal/params/cyberghost.go +++ b/internal/params/cyberghost.go @@ -3,6 +3,8 @@ package params import ( "encoding/pem" "fmt" + "io/ioutil" + "os" "strings" "github.com/qdm12/gluetun/internal/constants" @@ -32,10 +34,19 @@ func (p *reader) GetCyberghostClientKey() (clientKey string, err error) { } else if len(clientKey) > 0 { return clientKey, nil } - content, err := p.fileManager.ReadFile(string(constants.ClientKey)) + const filepath = string(constants.ClientKey) + file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0) if err != nil { return "", err } + content, err := ioutil.ReadAll(file) + if err != nil { + _ = file.Close() + return "", err + } + if err := file.Close(); err != nil { + return "", err + } return extractClientKey(content) } @@ -55,10 +66,19 @@ func extractClientKey(b []byte) (key string, err error) { // GetCyberghostClientCertificate obtains the client certificate to use for openvpn from the // file at /gluetun/client.crt. func (p *reader) GetCyberghostClientCertificate() (clientCertificate string, err error) { - content, err := p.fileManager.ReadFile(string(constants.ClientCertificate)) + const filepath = string(constants.ClientCertificate) + file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0) if err != nil { return "", err } + content, err := ioutil.ReadAll(file) + if err != nil { + _ = file.Close() + return "", err + } + if err := file.Close(); err != nil { + return "", err + } return extractClientCertificate(content) } diff --git a/internal/params/openvpn.go b/internal/params/openvpn.go index 5ba3cda6..caa28c31 100644 --- a/internal/params/openvpn.go +++ b/internal/params/openvpn.go @@ -11,7 +11,7 @@ import ( // GetUser obtains the user to use to connect to the VPN servers. func (r *reader) GetUser() (s string, err error) { defer func() { - unsetenvErr := r.unsetEnv("USER") + unsetenvErr := r.os.Unsetenv("USER") if err == nil { err = unsetenvErr } @@ -22,7 +22,7 @@ func (r *reader) GetUser() (s string, err error) { // GetPassword obtains the password to use to connect to the VPN servers. func (r *reader) GetPassword(required bool) (s string, err error) { defer func() { - unsetenvErr := r.unsetEnv("PASSWORD") + unsetenvErr := r.os.Unsetenv("PASSWORD") if err == nil { err = unsetenvErr } diff --git a/internal/params/params.go b/internal/params/params.go index c1b40709..46ad9119 100644 --- a/internal/params/params.go +++ b/internal/params/params.go @@ -2,11 +2,10 @@ package params import ( "net" - "os" "time" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" libparams "github.com/qdm12/golibs/params" "github.com/qdm12/golibs/verification" @@ -128,22 +127,20 @@ type Reader interface { } type reader struct { - envParams libparams.EnvParams - logger logging.Logger - verifier verification.Verifier - unsetEnv func(key string) error - fileManager files.FileManager + envParams libparams.EnvParams + logger logging.Logger + verifier verification.Verifier + os os.OS } // Newreader returns a paramsReadeer object to read parameters from // environment variables. -func NewReader(logger logging.Logger, fileManager files.FileManager) Reader { +func NewReader(logger logging.Logger, os os.OS) Reader { return &reader{ - envParams: libparams.NewEnvParams(), - logger: logger, - verifier: verification.NewVerifier(), - unsetEnv: os.Unsetenv, - fileManager: fileManager, + envParams: libparams.NewEnvParams(), + logger: logger, + verifier: verification.NewVerifier(), + os: os, } } diff --git a/internal/params/shadowsocks.go b/internal/params/shadowsocks.go index 94a27fb0..b2703572 100644 --- a/internal/params/shadowsocks.go +++ b/internal/params/shadowsocks.go @@ -36,7 +36,7 @@ func (r *reader) GetShadowSocksPort() (port uint16, err error) { // SHADOWSOCKS_PASSWORD. func (r *reader) GetShadowSocksPassword() (password string, err error) { defer func() { - unsetErr := r.unsetEnv("SHADOWSOCKS_PASSWORD") + unsetErr := r.os.Unsetenv("SHADOWSOCKS_PASSWORD") if err == nil { err = unsetErr } diff --git a/internal/provider/cyberghost.go b/internal/provider/cyberghost.go index af0e0f27..ba8ddf4b 100644 --- a/internal/provider/cyberghost.go +++ b/internal/provider/cyberghost.go @@ -11,7 +11,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -133,7 +133,7 @@ func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity in } func (c *cyberghost) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for cyberghost") } diff --git a/internal/provider/mullvad.go b/internal/provider/mullvad.go index 7d4227ec..2977df7e 100644 --- a/internal/provider/mullvad.go +++ b/internal/provider/mullvad.go @@ -10,7 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -128,7 +128,7 @@ func (m *mullvad) BuildConf(connection models.OpenVPNConnection, } func (m *mullvad) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for mullvad") } diff --git a/internal/provider/nordvpn.go b/internal/provider/nordvpn.go index 1b94127d..666c2eb5 100644 --- a/internal/provider/nordvpn.go +++ b/internal/provider/nordvpn.go @@ -10,7 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -142,7 +142,7 @@ func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity int, } func (n *nordvpn) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for nordvpn") } diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go index dd1d4d5e..b207e97e 100644 --- a/internal/provider/piav4.go +++ b/internal/provider/piav4.go @@ -19,7 +19,7 @@ import ( "github.com/qdm12/gluetun/internal/firewall" gluetunLog "github.com/qdm12/gluetun/internal/logging" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -183,7 +183,7 @@ func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity int, user //nolint:gocognit func (p *pia) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { if !p.activeServer.PortForward { pfLogger.Error("The server %s does not support port forwarding", p.activeServer.Region) @@ -203,7 +203,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, return } defer pfLogger.Warn("loop exited") - data, err := readPIAPortForwardData(fileManager) + data, err := readPIAPortForwardData(openFile) if err != nil { pfLogger.Error(err) } @@ -222,7 +222,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, if !dataFound || expired { tryUntilSuccessful(ctx, pfLogger, func() error { - data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager) + data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile) return err }) if ctx.Err() != nil { @@ -240,12 +240,9 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, return } - filepath := syncState(data.Port) + filepath := string(syncState(data.Port)) pfLogger.Info("Writing port to %s", filepath) - if err := fileManager.WriteToFile( - string(filepath), []byte(fmt.Sprintf("%d", data.Port)), - files.Permissions(constants.AllReadWritePermissions), - ); err != nil { + if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { pfLogger.Error(err) } @@ -281,7 +278,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) oldPort := data.Port for { - data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager) + data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile) if err != nil { pfLogger.Error(err) continue @@ -298,10 +295,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, } filepath := syncState(data.Port) pfLogger.Info("Writing port to %s", filepath) - if err := fileManager.WriteToFile( - string(filepath), []byte(fmt.Sprintf("%d", data.Port)), - files.Permissions(constants.AllReadWritePermissions), - ); err != nil { + if err := writePortForwardedToFile(openFile, string(filepath), data.Port); err != nil { pfLogger.Error(err) } if err := bindPIAPort(ctx, client, gateway, data); err != nil { @@ -365,8 +359,8 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) { } func refreshPIAPortForwardData(ctx context.Context, client *http.Client, - gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) { - data.Token, err = fetchPIAToken(ctx, fileManager, client) + gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) { + data.Token, err = fetchPIAToken(ctx, openFile, client) if err != nil { return data, fmt.Errorf("cannot obtain token: %w", err) } @@ -374,7 +368,7 @@ func refreshPIAPortForwardData(ctx context.Context, client *http.Client, if err != nil { return data, fmt.Errorf("cannot obtain port forwarding data: %w", err) } - if err := writePIAPortForwardData(fileManager, data); err != nil { + if err := writePIAPortForwardData(openFile, data); err != nil { return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err) } return data, nil @@ -393,34 +387,39 @@ type piaPortForwardData struct { Expiration time.Time `json:"expires_at"` } -func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) { +func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) { const filepath = string(constants.PIAPortForward) - exists, err := fileManager.FileExists(filepath) - if err != nil { - return data, err - } else if !exists { + file, err := openFile(filepath, os.O_RDONLY, 0) + if os.IsNotExist(err) { return data, nil + } else if err != nil { + return data, err } - b, err := fileManager.ReadFile(filepath) + + decoder := json.NewDecoder(file) + err = decoder.Decode(&data) if err != nil { + _ = file.Close() return data, err } - if err := json.Unmarshal(b, &data); err != nil { - return data, err - } - return data, nil + return data, file.Close() } -func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) { - b, err := json.Marshal(&data) - if err != nil { - return fmt.Errorf("cannot encode data: %w", err) - } - err = fileManager.WriteToFile(string(constants.PIAPortForward), b) +func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) { + const filepath = string(constants.PIAPortForward) + file, err := openFile(filepath, + os.O_CREATE|os.O_TRUNC|os.O_WRONLY, + 0644) if err != nil { return err } - return nil + encoder := json.NewEncoder(file) + err = encoder.Encode(data) + if err != nil { + _ = file.Close() + return err + } + return file.Close() } func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) { @@ -449,8 +448,9 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st return payload, nil } -func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *http.Client) (token string, err error) { - username, password, err := getOpenvpnCredentials(fileManager) +func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc, + client *http.Client) (token string, err error) { + username, password, err := getOpenvpnCredentials(openFile) if err != nil { return "", fmt.Errorf("cannot get Openvpn credentials: %w", err) } @@ -489,10 +489,19 @@ func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *h return result.Token, nil } -func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) { - authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf)) +func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) { + const filepath = string(constants.OpenVPNAuthConf) + file, err := openFile(filepath, os.O_RDONLY, 0) if err != nil { - return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err) + return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err) + } + authData, err := ioutil.ReadAll(file) + if err != nil { + _ = file.Close() + return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err) + } + if err := file.Close(); err != nil { + return "", "", err } lines := strings.Split(string(authData), "\n") const minLines = 2 @@ -586,3 +595,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data } return nil } + +func writePortForwardedToFile(openFile os.OpenFileFunc, + filepath string, port uint16) (err error) { + file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return err + } + _, err = file.Write([]byte(fmt.Sprintf("%d", port))) + if err != nil { + _ = file.Close() + return err + } + return file.Close() +} diff --git a/internal/provider/privado.go b/internal/provider/privado.go index 7461848d..6e061d5b 100644 --- a/internal/provider/privado.go +++ b/internal/provider/privado.go @@ -10,7 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -117,7 +117,7 @@ func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity int, } func (s *privado) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for privado") } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index ef76c2b1..2bb247c3 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -8,7 +8,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -18,7 +18,7 @@ type Provider interface { BuildConf(connection models.OpenVPNConnection, verbosity int, username string, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) } diff --git a/internal/provider/purevpn.go b/internal/provider/purevpn.go index addead66..155a3758 100644 --- a/internal/provider/purevpn.go +++ b/internal/provider/purevpn.go @@ -10,7 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -150,7 +150,7 @@ func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity int, } func (p *purevpn) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for purevpn") } diff --git a/internal/provider/surfshark.go b/internal/provider/surfshark.go index 951fac2d..9a1ce731 100644 --- a/internal/provider/surfshark.go +++ b/internal/provider/surfshark.go @@ -10,7 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -138,7 +138,7 @@ func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity int } func (s *surfshark) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for surfshark") } diff --git a/internal/provider/vyprvpn.go b/internal/provider/vyprvpn.go index 22cb0392..9ce7d589 100644 --- a/internal/provider/vyprvpn.go +++ b/internal/provider/vyprvpn.go @@ -10,7 +10,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -119,7 +119,7 @@ func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity int, } func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for vyprvpn") } diff --git a/internal/provider/windscribe.go b/internal/provider/windscribe.go index 11ae0eb3..b7a7ac32 100644 --- a/internal/provider/windscribe.go +++ b/internal/provider/windscribe.go @@ -11,7 +11,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" - "github.com/qdm12/golibs/files" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -132,7 +132,7 @@ func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity in } func (w *windscribe) PortForward(ctx context.Context, client *http.Client, - fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, + openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, syncState func(port uint16) (pfFilepath models.Filepath)) { panic("port forwarding is not supported for windscribe") } diff --git a/internal/publicip/fs.go b/internal/publicip/fs.go new file mode 100644 index 00000000..dbaa5b14 --- /dev/null +++ b/internal/publicip/fs.go @@ -0,0 +1,27 @@ +package publicip + +import "github.com/qdm12/gluetun/internal/os" + +func persistPublicIP(openFile os.OpenFileFunc, + filepath string, content string, uid, gid int) error { + file, err := openFile( + filepath, + os.O_TRUNC|os.O_WRONLY|os.O_CREATE, + 0644) + if err != nil { + return err + } + + _, err = file.WriteString(content) + if err != nil { + _ = file.Close() + return err + } + + if err := file.Chown(uid, gid); err != nil { + _ = file.Close() + return err + } + + return file.Close() +} diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index c9eb92b3..d3a6dc11 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -8,8 +8,8 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/settings" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/network" ) @@ -27,9 +27,9 @@ type Looper interface { type looper struct { state state // Objects - getter IPGetter - logger logging.Logger - fileManager files.FileManager + getter IPGetter + logger logging.Logger + os os.OS // Fixed settings uid int gid int @@ -45,8 +45,9 @@ type looper struct { timeSince func(time.Time) time.Duration } -func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, - settings settings.PublicIP, uid, gid int) Looper { +func NewLooper(client network.Client, logger logging.Logger, + settings settings.PublicIP, uid, gid int, + os os.OS) Looper { return &looper{ state: state{ status: constants.Stopped, @@ -55,7 +56,7 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F // Objects getter: NewIPGetter(client), logger: logger.WithPrefix("ip getter: "), - fileManager: fileManager, + os: os, uid: uid, gid: gid, start: make(chan struct{}), @@ -125,7 +126,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { close(errorCh) filepath := l.GetSettings().IPFilepath l.logger.Info("Removing ip file %s", filepath) - if err := l.fileManager.Remove(string(filepath)); err != nil { + if err := l.os.Remove(string(filepath)); err != nil { l.logger.Error(err) } return @@ -142,12 +143,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { getCancel() l.state.setPublicIP(ip) l.logger.Info("Public IP address is %s", ip) - const userReadWritePermissions = 0600 - err := l.fileManager.WriteLinesToFile( - string(l.state.settings.IPFilepath), - []string{ip.String()}, - files.Ownership(l.uid, l.gid), - files.Permissions(userReadWritePermissions)) + filepath := string(l.state.settings.IPFilepath) + err := persistPublicIP(l.os.OpenFile, filepath, ip.String(), l.uid, l.gid) if err != nil { l.logger.Error(err) } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 719085cb..35ac4854 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -1,10 +1,8 @@ package storage import ( - "io/ioutil" - "os" - "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/os" "github.com/qdm12/golibs/logging" ) @@ -14,17 +12,13 @@ type Storage interface { } type storage struct { - osStat func(name string) (os.FileInfo, error) - readFile func(filename string) (data []byte, err error) - writeFile func(filename string, data []byte, perm os.FileMode) error - logger logging.Logger + os os.OS + logger logging.Logger } -func New(logger logging.Logger) Storage { +func New(logger logging.Logger, os os.OS) Storage { return &storage{ - osStat: os.Stat, - readFile: ioutil.ReadFile, - writeFile: ioutil.WriteFile, - logger: logger.WithPrefix("storage: "), + os: os, + logger: logger.WithPrefix("storage: "), } } diff --git a/internal/storage/sync.go b/internal/storage/sync.go index 87e8e806..1327a439 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -3,10 +3,10 @@ package storage import ( "encoding/json" "fmt" - "os" "reflect" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/os" ) const ( @@ -29,14 +29,18 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) ( allServers models.AllServers, err error) { // Eventually read file var serversOnFile models.AllServers - _, err = s.osStat(jsonFilepath) + file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0) + if err != nil && !os.IsNotExist(err) { + return allServers, err + } if err == nil { - serversOnFile, err = s.readFromFile() - if err != nil { + var serversOnFile models.AllServers + decoder := json.NewDecoder(file) + if err := decoder.Decode(&serversOnFile); err != nil { + _ = file.Close() return allServers, err } - } else if !os.IsNotExist(err) { - return allServers, err + return allServers, file.Close() } // Merge data from file and hardcoded @@ -51,24 +55,16 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) ( return allServers, s.FlushToFile(allServers) } -func (s *storage) readFromFile() (servers models.AllServers, err error) { - bytes, err := s.readFile(jsonFilepath) - if err != nil { - return servers, err - } - if err := json.Unmarshal(bytes, &servers); err != nil { - return servers, err - } - return servers, nil -} - func (s *storage) FlushToFile(servers models.AllServers) error { - bytes, err := json.MarshalIndent(servers, "", " ") + file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - return fmt.Errorf("cannot write to file: %w", err) - } - if err := s.writeFile(jsonFilepath, bytes, 0644); err != nil { return err } - return nil + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + if err := encoder.Encode(servers); err != nil { + _ = file.Close() + return fmt.Errorf("cannot write to file: %w", err) + } + return file.Close() }