Code maintenance: OS package for file system

- OS custom internal package for file system interaction
- Remove fileManager external dependency
- Closer API to Go's native API on the OS
- Create directories at startup
- Better testability
- Move Unsetenv to os interface
This commit is contained in:
Quentin McGaw
2020-12-29 00:55:31 +00:00
parent f5366c33bc
commit 73479bab26
43 changed files with 923 additions and 353 deletions

View File

@@ -10,6 +10,10 @@ issues:
linters: linters:
- dupl - dupl
- maligned - maligned
- path: internal/os/alias\.go
linters:
- gochecknoglobals
text: IsNotExist is a global variable
linters: linters:
disable-all: true disable-all: true

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"os" nativeos "os"
"os/signal" "os/signal"
"strings" "strings"
"sync" "sync"
@@ -22,6 +22,7 @@ import (
gluetunLogging "github.com/qdm12/gluetun/internal/logging" gluetunLogging "github.com/qdm12/gluetun/internal/logging"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/params"
"github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
@@ -32,7 +33,6 @@ import (
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
versionpkg "github.com/qdm12/gluetun/internal/version" versionpkg "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
) )
@@ -50,21 +50,24 @@ func main() {
buildInfo.Commit = commit buildInfo.Commit = commit
buildInfo.BuildDate = buildDate buildInfo.BuildDate = buildDate
ctx := context.Background() ctx := context.Background()
os.Exit(_main(ctx, os.Args)) args := nativeos.Args
os := os.New()
nativeos.Exit(_main(ctx, args, os))
} }
func _main(background context.Context, args []string) int { //nolint:gocognit,gocyclo //nolint:gocognit,gocyclo
func _main(background context.Context, args []string, os os.OS) int {
if len(args) > 1 { // cli operation if len(args) > 1 { // cli operation
var err error var err error
switch args[1] { switch args[1] {
case "healthcheck": case "healthcheck":
err = cli.HealthCheck(background) err = cli.HealthCheck(background)
case "clientkey": case "clientkey":
err = cli.ClientKey(args[2:]) err = cli.ClientKey(args[2:], os.OpenFile)
case "openvpnconfig": case "openvpnconfig":
err = cli.OpenvpnConfig() err = cli.OpenvpnConfig(os)
case "update": case "update":
err = cli.Update(args[2:]) err = cli.Update(args[2:], os)
default: default:
err = fmt.Errorf("command %q is unknown", args[1]) err = fmt.Errorf("command %q is unknown", args[1])
} }
@@ -82,15 +85,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
client := network.NewClient(clientTimeout) client := network.NewClient(clientTimeout)
// Create configurators // Create configurators
fileManager := files.NewFileManager() alpineConf := alpine.NewConfigurator(os.OpenFile)
alpineConf := alpine.NewConfigurator(fileManager) ovpnConf := openvpn.NewConfigurator(logger, os)
ovpnConf := openvpn.NewConfigurator(logger, fileManager) dnsConf := dns.NewConfigurator(logger, client, os.OpenFile)
dnsConf := dns.NewConfigurator(logger, client, fileManager)
routingConf := routing.NewRouting(logger) routingConf := routing.NewRouting(logger)
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager) firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile)
streamMerger := command.NewStreamMerger() streamMerger := command.NewStreamMerger()
paramsReader := params.NewReader(logger, fileManager) paramsReader := params.NewReader(logger, os)
fmt.Println(gluetunLogging.Splash(buildInfo)) fmt.Println(gluetunLogging.Splash(buildInfo))
printVersions(ctx, logger, map[string]func(ctx context.Context) (string, error){ printVersions(ctx, logger, map[string]func(ctx context.Context) (string, error){
@@ -106,8 +108,17 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
} }
logger.Info(allSettings.String()) logger.Info(allSettings.String())
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
logger.Error(err)
return 1
}
if err := os.MkdirAll("/gluetun", 0644); err != nil {
logger.Error(err)
return 1
}
// TODO run this in a loop or in openvpn to reload from file without restarting // TODO run this in a loop or in openvpn to reload from file without restarting
storage := storage.New(logger) storage := storage.New(logger, os)
const updateServerFile = true const updateServerFile = true
allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile) allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile)
if err != nil { if err != nil {
@@ -124,8 +135,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
logger.Error(err) logger.Error(err)
return 1 return 1
} }
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
if err != nil { if err := os.Chown("/etc/unbound", uid, gid); err != nil {
logger.Error(err) logger.Error(err)
return 1 return 1
} }
@@ -219,7 +230,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, uid, gid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, uid, gid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel) ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
wg.Add(1) wg.Add(1)
// wait for restartOpenvpn // wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg) go openvpnLooper.Run(ctx, wg)
@@ -236,7 +247,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go unboundLooper.Run(ctx, wg, signalDNSReady) go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper( publicIPLooper := publicip.NewLooper(
client, logger, fileManager, allSettings.PublicIP, uid, gid) client, logger, allSettings.PublicIP, uid, gid, os)
wg.Add(1) wg.Add(1)
go publicIPLooper.Run(ctx, wg) go publicIPLooper.Run(ctx, wg)
wg.Add(1) wg.Add(1)
@@ -279,11 +290,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
// until openvpn is launched // until openvpn is launched
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable _, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
signalsCh := make(chan os.Signal, 1) signalsCh := make(chan nativeos.Signal, 1)
signal.Notify(signalsCh, signal.Notify(signalsCh,
syscall.SIGINT, syscall.SIGINT,
syscall.SIGTERM, syscall.SIGTERM,
os.Interrupt, nativeos.Interrupt,
) )
shutdownErrorsCount := 0 shutdownErrorsCount := 0
select { select {
@@ -295,7 +306,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
} }
if allSettings.OpenVPN.Provider.PortForwarding.Enabled { if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
logger.Error(err) logger.Error(err)
shutdownErrorsCount++ shutdownErrorsCount++
} }

View File

@@ -3,7 +3,7 @@ package alpine
import ( import (
"os/user" "os/user"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
) )
type Configurator interface { type Configurator interface {
@@ -11,15 +11,15 @@ type Configurator interface {
} }
type configurator struct { type configurator struct {
fileManager files.FileManager openFile os.OpenFileFunc
lookupUID func(uid string) (*user.User, error) lookupUID func(uid string) (*user.User, error)
lookupUser func(username string) (*user.User, error) lookupUser func(username string) (*user.User, error)
} }
func NewConfigurator(fileManager files.FileManager) Configurator { func NewConfigurator(openFile os.OpenFileFunc) Configurator {
return &configurator{ return &configurator{
fileManager: fileManager, openFile: openFile,
lookupUID: user.LookupId, lookupUID: user.LookupId,
lookupUser: user.Lookup, lookupUser: user.Lookup,
} }
} }

View File

@@ -2,6 +2,7 @@ package alpine
import ( import (
"fmt" "fmt"
"os"
"os/user" "os/user"
) )
@@ -26,14 +27,15 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
return "", fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d", return "", fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d",
username, u.Uid, uid) username, u.Uid, uid)
} }
passwd, err := c.fileManager.ReadFile("/etc/passwd") file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return "", fmt.Errorf("cannot create user: %w", err) return "", fmt.Errorf("cannot create user: %w", err)
} }
passwd = append(passwd, []byte(fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid))...) s := fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid)
_, err = file.WriteString(s)
if err := c.fileManager.WriteToFile("/etc/passwd", passwd); err != nil { if err != nil {
return "", fmt.Errorf("cannot create user: %w", err) _ = file.Close()
return "", err
} }
return username, nil return username, file.Close()
} }

View File

@@ -4,29 +4,40 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/healthcheck" "github.com/qdm12/gluetun/internal/healthcheck"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/params"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
func ClientKey(args []string) error { func ClientKey(args []string, openFile os.OpenFileFunc) error {
flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError) flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError)
filepath := flagSet.String("path", string(constants.ClientKey), "file path to the client.key file") filepath := flagSet.String("path", string(constants.ClientKey), "file path to the client.key file")
if err := flagSet.Parse(args); err != nil { if err := flagSet.Parse(args); err != nil {
return err return err
} }
fileManager := files.NewFileManager() file, err := openFile(*filepath, os.O_RDONLY, 0)
data, err := fileManager.ReadFile(*filepath) if err != nil {
return err
}
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
if err != nil { if err != nil {
return err return err
} }
@@ -49,17 +60,17 @@ func HealthCheck(ctx context.Context) error {
return healthchecker.Check(ctx, url) return healthchecker.Check(ctx, url)
} }
func OpenvpnConfig() error { func OpenvpnConfig(os os.OS) error {
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel) logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
if err != nil { if err != nil {
return err return err
} }
paramsReader := params.NewReader(logger, files.NewFileManager()) paramsReader := params.NewReader(logger, os)
allSettings, err := settings.GetAllSettings(paramsReader) allSettings, err := settings.GetAllSettings(paramsReader)
if err != nil { if err != nil {
return err return err
} }
allServers, err := storage.New(logger).SyncServers(constants.GetAllServers(), false) allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false)
if err != nil { if err != nil {
return err return err
} }
@@ -81,7 +92,7 @@ func OpenvpnConfig() error {
return nil return nil
} }
func Update(args []string) error { func Update(args []string, os os.OS) error {
options := settings.Updater{CLI: true} options := settings.Updater{CLI: true}
var flushToFile bool var flushToFile bool
flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet := flag.NewFlagSet("update", flag.ExitOnError)
@@ -110,7 +121,7 @@ func Update(args []string) error {
ctx := context.Background() ctx := context.Background()
const clientTimeout = 10 * time.Second const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
storage := storage.New(logger) storage := storage.New(logger, os)
const writeSync = false const writeSync = false
currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync) currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync)
if err != nil { if err != nil {

View File

@@ -1,8 +0,0 @@
package constants
import "os"
const (
UserReadPermission os.FileMode = 0400
AllReadWritePermissions os.FileMode = 0666
)

View File

@@ -8,8 +8,8 @@ import (
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
) )
@@ -21,11 +21,29 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
for _, warning := range warnings { for _, warning := range warnings {
c.logger.Warn(warning) c.logger.Warn(warning)
} }
return c.fileManager.WriteLinesToFile(
string(constants.UnboundConf), const filepath = string(constants.UnboundConf)
lines, file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
files.Ownership(uid, gid), if err != nil {
files.Permissions(constants.UserReadPermission)) return err
}
_, err = file.WriteString(strings.Join(lines, "\n"))
if err != nil {
_ = file.Close()
return err
}
if err := file.Chown(uid, gid); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
} }
// MakeUnboundConf generates an Unbound configuration from the user provided settings. // MakeUnboundConf generates an Unbound configuration from the user provided settings.

View File

@@ -5,9 +5,9 @@ import (
"io" "io"
"net" "net"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
) )
@@ -24,19 +24,20 @@ type Configurator interface {
} }
type configurator struct { type configurator struct {
logger logging.Logger logger logging.Logger
client network.Client client network.Client
fileManager files.FileManager openFile os.OpenFileFunc
commander command.Commander commander command.Commander
lookupIP func(host string) ([]net.IP, error) lookupIP func(host string) ([]net.IP, error)
} }
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator { func NewConfigurator(logger logging.Logger, client network.Client,
openFile os.OpenFileFunc) Configurator {
return &configurator{ return &configurator{
logger: logger.WithPrefix("dns configurator: "), logger: logger.WithPrefix("dns configurator: "),
client: client, client: client,
fileManager: fileManager, openFile: openFile,
commander: command.NewCommander(), commander: command.NewCommander(),
lookupIP: net.LookupIP, lookupIP: net.LookupIP,
} }
} }

View File

@@ -2,10 +2,12 @@ package dns
import ( import (
"context" "context"
"io/ioutil"
"net" "net"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/os"
) )
// UseDNSInternally is to change the Go program DNS only. // UseDNSInternally is to change the Go program DNS only.
@@ -23,10 +25,16 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
// UseDNSSystemWide changes the nameserver to use for DNS system wide. // UseDNSSystemWide changes the nameserver to use for DNS system wide.
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String()) c.logger.Info("using DNS address %s system wide", ip.String())
data, err := c.fileManager.ReadFile(string(constants.ResolvConf)) const filepath = string(constants.ResolvConf)
file, err := c.openFile(filepath, os.O_RDWR, 0644)
if err != nil { if err != nil {
return err return err
} }
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
s := strings.TrimSuffix(string(data), "\n") s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n") lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" { if len(lines) == 1 && lines[0] == "" {
@@ -44,6 +52,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
if !found { if !found {
lines = append(lines, "nameserver "+ip.String()) lines = append(lines, "nameserver "+ip.String())
} }
data = []byte(strings.Join(lines, "\n")) s = strings.Join(lines, "\n")
return c.fileManager.WriteToFile(string(constants.ResolvConf), data) _, err = file.WriteString(s)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
} }

View File

@@ -2,12 +2,14 @@ package dns
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files/mock_files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging" "github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -17,30 +19,36 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel() t.Parallel()
tests := map[string]struct { tests := map[string]struct {
data []byte data []byte
writtenData []byte writtenData string
openErr error
readErr error readErr error
writeErr error writeErr error
closeErr error
err error err error
}{ }{
"no data": { "no data": {
writtenData: []byte("nameserver 127.0.0.1"), writtenData: "nameserver 127.0.0.1",
},
"open error": {
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
}, },
"read error": { "read error": {
readErr: fmt.Errorf("error"), readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"write error": { "write error": {
writtenData: []byte("nameserver 127.0.0.1"), writtenData: "nameserver 127.0.0.1",
writeErr: fmt.Errorf("error"), writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"lines without nameserver": { "lines without nameserver": {
data: []byte("abc\ndef\n"), data: []byte("abc\ndef\n"),
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"), writtenData: "abc\ndef\nnameserver 127.0.0.1",
}, },
"lines with nameserver": { "lines with nameserver": {
data: []byte("abc\nnameserver abc def\ndef\n"), data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"), writtenData: "abc\nnameserver 127.0.0.1\ndef",
}, },
} }
for name, tc := range tests { for name, tc := range tests {
@@ -49,18 +57,43 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t) mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish() defer mockCtrl.Finish()
fileManager := mock_files.NewMockFileManager(mockCtrl)
fileManager.EXPECT().ReadFile(string(constants.ResolvConf)). file := mock_os.NewMockFile(mockCtrl)
Return(tc.data, tc.readErr) if tc.openErr == nil {
if tc.readErr == nil { firstReadCall := file.EXPECT().
fileManager.EXPECT().WriteToFile(string(constants.ResolvConf), tc.writtenData). Read(gomock.AssignableToTypeOf([]byte{})).
Return(tc.writeErr) DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data)
return len(tc.data), nil
})
readErr := tc.readErr
if readErr == nil {
readErr = io.EOF
}
finalReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall)
if tc.readErr == nil {
writeCall := file.EXPECT().WriteString(tc.writtenData).
Return(0, tc.writeErr).After(finalReadCall)
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
}
} }
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.ResolvConf), name)
assert.Equal(t, os.O_RDWR, flag)
assert.Equal(t, os.FileMode(0644), perm)
return file, tc.openErr
}
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1") logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
c := &configurator{ c := &configurator{
fileManager: fileManager, openFile: openFile,
logger: logger, logger: logger,
} }
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false) err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil { if tc.err != nil {

View File

@@ -4,37 +4,46 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"os"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
) )
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error { func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root hints from %s", constants.NamedRootURL) return c.downloadAndSave(ctx, "root hints",
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL)) string(constants.NamedRootURL), string(constants.RootHints), uid, gid)
if err != nil {
return err
} else if status != http.StatusOK {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
}
return c.fileManager.WriteToFile(
string(constants.RootHints),
content,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
} }
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error { func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root key from %s", constants.RootKeyURL) return c.downloadAndSave(ctx, "root key",
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL)) string(constants.RootKeyURL), string(constants.RootKey), uid, gid)
}
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
c.logger.Info("downloading %s from %s", logName, url)
content, status, err := c.client.Get(ctx, url)
if err != nil { if err != nil {
return err return err
} else if status != http.StatusOK { } else if status != http.StatusOK {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL) return fmt.Errorf("HTTP status code is %d for %s", status, url)
} }
return c.fileManager.WriteToFile(
string(constants.RootKey), file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
content, if err != nil {
files.Ownership(uid, gid), return err
files.Permissions(constants.UserReadPermission)) }
_, err = file.Write(content)
if err != nil {
_ = file.Close()
return err
}
err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
} }

View File

@@ -2,27 +2,31 @@ package dns
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/files/mock_files" "github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging" "github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/network/mock_network" "github.com/qdm12/golibs/network/mock_network"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func Test_DownloadRootHints(t *testing.T) { //nolint:dupl func Test_downloadAndSave(t *testing.T) {
t.Parallel() t.Parallel()
tests := map[string]struct { tests := map[string]struct {
content []byte content []byte
status int status int
clientErr error clientErr error
openErr error
writeErr error writeErr error
chownErr error
closeErr error
err error err error
}{ }{
"no data": { "no data": {
@@ -36,11 +40,26 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
clientErr: fmt.Errorf("error"), clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"open error": {
status: http.StatusOK,
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": { "write error": {
status: http.StatusOK, status: http.StatusOK,
writeErr: fmt.Errorf("error"), writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"chown error": {
status: http.StatusOK,
chownErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"close error": {
status: http.StatusOK,
closeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": { "data": {
content: []byte("content"), content: []byte("content"),
status: http.StatusOK, status: http.StatusOK,
@@ -52,23 +71,49 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t) mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish() defer mockCtrl.Finish()
ctx := context.Background() ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL) logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := mock_network.NewMockClient(mockCtrl) client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)). client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
Return(tc.content, tc.status, tc.clientErr) Return(tc.content, tc.status, tc.clientErr)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK { openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
fileManager.EXPECT().WriteToFile( return nil, nil
string(constants.RootHints),
tc.content,
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
gomock.AssignableToTypeOf(files.Ownership(0, 0))).
Return(tc.writeErr)
} }
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootHints(ctx, 1000, 1000) if tc.clientErr == nil && tc.status == http.StatusOK {
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
writeCall := file.EXPECT().Write(tc.content).
Return(0, tc.writeErr)
if tc.writeErr != nil {
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
}
}
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.RootHints), name)
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
assert.Equal(t, os.FileMode(0400), perm)
return file, tc.openErr
}
}
c := &configurator{
logger: logger,
client: client,
openFile: openFile,
}
err := c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints),
1000, 1000)
if tc.err != nil { if tc.err != nil {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error()) assert.Equal(t, tc.err.Error(), err.Error())
@@ -79,65 +124,44 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
} }
} }
func Test_DownloadRootKey(t *testing.T) { //nolint:dupl func Test_DownloadRootHints(t *testing.T) {
t.Parallel() t.Parallel()
tests := map[string]struct { mockCtrl := gomock.NewController(t)
content []byte
status int ctx := context.Background()
clientErr error logger := mock_logging.NewMockLogger(mockCtrl)
writeErr error logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
err error client := mock_network.NewMockClient(mockCtrl)
}{ client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
"no data": { Return(nil, http.StatusOK, errors.New("test"))
status: http.StatusOK,
}, c := &configurator{
"bad status": { logger: logger,
status: http.StatusBadRequest, client: client,
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"), //nolint:lll
},
"client error": {
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL)
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(tc.content, tc.status, tc.clientErr)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.EXPECT().WriteToFile(
string(constants.RootKey),
tc.content,
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
).Return(tc.writeErr)
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootKey(ctx, 1000, 1001)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
} }
err := c.DownloadRootHints(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, "test", err.Error())
}
func Test_DownloadRootKey(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(nil, http.StatusOK, errors.New("test"))
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootKey(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, "test", err.Error())
} }

View File

@@ -6,9 +6,9 @@ import (
"sync" "sync"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -29,7 +29,7 @@ type configurator struct { //nolint:maligned
commander command.Commander commander command.Commander
logger logging.Logger logger logging.Logger
routing routing.Routing routing routing.Routing
fileManager files.FileManager // for custom iptables rules openFile os.OpenFileFunc // for custom iptables rules
iptablesMutex sync.Mutex iptablesMutex sync.Mutex
debug bool debug bool
defaultInterface string defaultInterface string
@@ -47,12 +47,12 @@ type configurator struct { //nolint:maligned
} }
// NewConfigurator creates a new Configurator instance. // NewConfigurator creates a new Configurator instance.
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator { func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator {
return &configurator{ return &configurator{
commander: command.NewCommander(), commander: command.NewCommander(),
logger: logger.WithPrefix("firewall: "), logger: logger.WithPrefix("firewall: "),
routing: routing, routing: routing,
fileManager: fileManager, openFile: openFile,
allowedInputPorts: make(map[uint16]string), allowedInputPorts: make(map[uint16]string),
} }
} }

View File

@@ -3,7 +3,9 @@ package firewall
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os"
"strings" "strings"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
@@ -150,14 +152,18 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port
} }
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
exists, err := c.fileManager.FileExists(filepath) file, err := c.openFile(filepath, os.O_RDONLY, 0)
if err != nil { if os.IsNotExist(err) {
return err
} else if !exists {
return nil return nil
} else if err != nil {
return err
} }
b, err := c.fileManager.ReadFile(filepath) b, err := ioutil.ReadAll(file)
if err != nil { if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err return err
} }
lines := strings.Split(string(b), "\n") lines := strings.Split(string(b), "\n")

View File

@@ -1,31 +1,68 @@
package openvpn package openvpn
import ( import (
"io/ioutil"
"os"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
) )
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions. // WriteAuthFile writes the OpenVPN auth file to disk with the right permissions.
func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error { func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error {
exists, err := c.fileManager.FileExists(string(constants.OpenVPNAuthConf)) const filepath = string(constants.OpenVPNAuthConf)
if err != nil { file, err := c.os.OpenFile(filepath, os.O_RDONLY, 0)
if err != nil && !os.IsNotExist(err) {
return err return err
} else if exists { }
data, err := c.fileManager.ReadFile(string(constants.OpenVPNAuthConf))
if os.IsNotExist(err) {
file, err = c.os.OpenFile(filepath, os.O_WRONLY|os.O_CREATE, 0400)
if err != nil { if err != nil {
return err return err
} }
lines := strings.Split(string(data), "\n") _, err = file.WriteString(user + "\n" + password)
if len(lines) > 1 && lines[0] == user && lines[1] == password { if err != nil {
return nil _ = file.Close()
return err
} }
c.logger.Info("username and password changed", constants.OpenVPNAuthConf) err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
} }
return c.fileManager.WriteLinesToFile(
string(constants.OpenVPNAuthConf), data, err := ioutil.ReadAll(file)
[]string{user, password}, if err != nil {
files.Ownership(uid, gid), _ = file.Close()
files.Permissions(constants.UserReadPermission)) return err
}
if err := file.Close(); err != nil {
return err
}
lines := strings.Split(string(data), "\n")
if len(lines) > 1 && lines[0] == user && lines[1] == password {
return nil
}
c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf)
file, err = c.os.OpenFile(filepath, os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil {
return err
}
_, err = file.WriteString(user + "\n" + password)
if err != nil {
_ = file.Close()
return err
}
err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
} }

View File

@@ -4,17 +4,18 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -43,7 +44,7 @@ type looper struct {
// Other objects // Other objects
logger, pfLogger logging.Logger logger, pfLogger logging.Logger
client *http.Client client *http.Client
fileManager files.FileManager openFile os.OpenFileFunc
streamMerger command.StreamMerger streamMerger command.StreamMerger
cancel context.CancelFunc cancel context.CancelFunc
// Internal channels and locks // Internal channels and locks
@@ -57,7 +58,7 @@ type looper struct {
func NewLooper(settings settings.OpenVPN, func NewLooper(settings settings.OpenVPN,
username string, uid, gid int, allServers models.AllServers, username string, uid, gid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, routing routing.Routing, conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.Logger, client *http.Client, fileManager files.FileManager, logger logging.Logger, client *http.Client, openFile os.OpenFileFunc,
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
return &looper{ return &looper{
state: state{ state: state{
@@ -74,7 +75,7 @@ func NewLooper(settings settings.OpenVPN,
logger: logger.WithPrefix("openvpn: "), logger: logger.WithPrefix("openvpn: "),
pfLogger: logger.WithPrefix("port forwarding: "), pfLogger: logger.WithPrefix("port forwarding: "),
client: client, client: client,
fileManager: fileManager, openFile: openFile,
streamMerger: streamMerger, streamMerger: streamMerger,
cancel: cancel, cancel: cancel,
start: make(chan struct{}), start: make(chan struct{}),
@@ -115,8 +116,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
settings.Auth, settings.Auth,
settings.Provider.ExtraConfigOptions, settings.Provider.ExtraConfigOptions,
) )
if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines,
files.Ownership(l.uid, l.gid), files.Permissions(constants.UserReadPermission)); err != nil { if err := writeOpenvpnConf(lines, l.openFile); err != nil {
l.logger.Error(err) l.logger.Error(err)
l.cancel() l.cancel()
return return
@@ -239,6 +240,22 @@ func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
return settings.Provider.PortForwarding.Filepath return settings.Provider.PortForwarding.Filepath
} }
providerConf.PortForward(ctx, providerConf.PortForward(ctx,
client, l.fileManager, l.pfLogger, client, l.openFile, l.pfLogger,
gateway, l.fw, syncState) gateway, l.fw, syncState)
} }
func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error {
const filepath = string(constants.OpenVPNConf)
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.WriteString(strings.Join(lines, "\n"))
if err != nil {
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
}

View File

@@ -3,10 +3,9 @@ package openvpn
import ( import (
"context" "context"
"io" "io"
"os"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -20,21 +19,19 @@ type Configurator interface {
} }
type configurator struct { type configurator struct {
fileManager files.FileManager logger logging.Logger
logger logging.Logger commander command.Commander
commander command.Commander os os.OS
openFile func(name string, flag int, perm os.FileMode) (*os.File, error) mkDev func(major uint32, minor uint32) uint64
mkDev func(major uint32, minor uint32) uint64 mkNod func(path string, mode uint32, dev int) error
mkNod func(path string, mode uint32, dev int) error
} }
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator { func NewConfigurator(logger logging.Logger, os os.OS) Configurator {
return &configurator{ return &configurator{
fileManager: fileManager, logger: logger.WithPrefix("openvpn configurator: "),
logger: logger.WithPrefix("openvpn configurator: "), commander: command.NewCommander(),
commander: command.NewCommander(), os: os,
openFile: os.OpenFile, mkDev: unix.Mkdev,
mkDev: unix.Mkdev, mkNod: unix.Mknod,
mkNod: unix.Mknod,
} }
} }

View File

@@ -11,7 +11,7 @@ import (
// CheckTUN checks the tunnel device is present and accessible. // CheckTUN checks the tunnel device is present and accessible.
func (c *configurator) CheckTUN() error { func (c *configurator) CheckTUN() error {
c.logger.Info("checking for device %s", constants.TunnelDevice) c.logger.Info("checking for device %s", constants.TunnelDevice)
f, err := c.openFile(string(constants.TunnelDevice), os.O_RDWR, 0) f, err := c.os.OpenFile(string(constants.TunnelDevice), os.O_RDWR, 0)
if err != nil { if err != nil {
return fmt.Errorf("TUN device is not available: %w", err) return fmt.Errorf("TUN device is not available: %w", err)
} }
@@ -23,9 +23,10 @@ func (c *configurator) CheckTUN() error {
func (c *configurator) CreateTUN() error { func (c *configurator) CreateTUN() error {
c.logger.Info("creating %s", constants.TunnelDevice) c.logger.Info("creating %s", constants.TunnelDevice)
if err := c.fileManager.CreateDir("/dev/net"); err != nil { if err := c.os.MkdirAll("/dev/net", 0751); err != nil {
return err return err
} }
const ( const (
major = 10 major = 10
minor = 200 minor = 200
@@ -34,8 +35,17 @@ func (c *configurator) CreateTUN() error {
if err := c.mkNod(string(constants.TunnelDevice), unix.S_IFCHR, int(dev)); err != nil { if err := c.mkNod(string(constants.TunnelDevice), unix.S_IFCHR, int(dev)); err != nil {
return err return err
} }
if err := c.fileManager.SetUserPermissions(string(constants.TunnelDevice), 0666); err != nil {
const filepath = string(constants.TunnelDevice)
file, err := c.os.OpenFile(filepath, os.O_WRONLY, 0666)
if err != nil {
return err return err
} }
return nil const readWriteAllPerms os.FileMode = 0666
if err := file.Chmod(readWriteAllPerms); err != nil {
_ = file.Close()
return err
}
return file.Close()
} }

9
internal/os/alias.go Normal file
View File

@@ -0,0 +1,9 @@
package os
import nativeos "os"
// Aliases used for convenience so "os" does not have to be imported
type FileMode nativeos.FileMode
var IsNotExist = nativeos.IsNotExist

16
internal/os/constants.go Normal file
View File

@@ -0,0 +1,16 @@
package os
import (
nativeos "os"
)
// Constants used for convenience so "os" does not have to be imported
//nolint:golint
const (
O_CREATE = nativeos.O_CREATE
O_TRUNC = nativeos.O_TRUNC
O_WRONLY = nativeos.O_WRONLY
O_RDONLY = nativeos.O_RDONLY
O_RDWR = nativeos.O_RDWR
)

15
internal/os/file.go Normal file
View File

@@ -0,0 +1,15 @@
package os
import (
"io"
nativeos "os"
)
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . File
type File interface {
io.ReadWriteCloser
WriteString(s string) (int, error)
Chown(uid, gid int) error
Chmod(mode nativeos.FileMode) error
}

10
internal/os/funcs.go Normal file
View File

@@ -0,0 +1,10 @@
package os
import (
nativeos "os"
)
type OpenFileFunc func(name string, flag int, perm FileMode) (File, error)
type MkdirAllFunc func(name string, perm nativeos.FileMode) error
type RemoveFunc func(name string) error
type ChownFunc func(name string, uid int, gid int) error

121
internal/os/mock_os/file.go Normal file
View File

@@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/os (interfaces: File)
// Package mock_os is a generated GoMock package.
package mock_os
import (
gomock "github.com/golang/mock/gomock"
os "os"
reflect "reflect"
)
// MockFile is a mock of File interface
type MockFile struct {
ctrl *gomock.Controller
recorder *MockFileMockRecorder
}
// MockFileMockRecorder is the mock recorder for MockFile
type MockFileMockRecorder struct {
mock *MockFile
}
// NewMockFile creates a new mock instance
func NewMockFile(ctrl *gomock.Controller) *MockFile {
mock := &MockFile{ctrl: ctrl}
mock.recorder = &MockFileMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockFile) EXPECT() *MockFileMockRecorder {
return m.recorder
}
// Chmod mocks base method
func (m *MockFile) Chmod(arg0 os.FileMode) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Chmod", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Chmod indicates an expected call of Chmod
func (mr *MockFileMockRecorder) Chmod(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chmod", reflect.TypeOf((*MockFile)(nil).Chmod), arg0)
}
// Chown mocks base method
func (m *MockFile) Chown(arg0, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Chown", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Chown indicates an expected call of Chown
func (mr *MockFileMockRecorder) Chown(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockFile)(nil).Chown), arg0, arg1)
}
// Close mocks base method
func (m *MockFile) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockFileMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFile)(nil).Close))
}
// Read mocks base method
func (m *MockFile) Read(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockFileMockRecorder) Read(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockFile)(nil).Read), arg0)
}
// Write mocks base method
func (m *MockFile) Write(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockFileMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockFile)(nil).Write), arg0)
}
// WriteString mocks base method
func (m *MockFile) WriteString(arg0 string) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WriteString", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// WriteString indicates an expected call of WriteString
func (mr *MockFileMockRecorder) WriteString(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockFile)(nil).WriteString), arg0)
}

121
internal/os/mock_os/os.go Normal file
View File

@@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/os (interfaces: OS)
// Package mock_os is a generated GoMock package.
package mock_os
import (
gomock "github.com/golang/mock/gomock"
os "github.com/qdm12/gluetun/internal/os"
os0 "os"
reflect "reflect"
)
// MockOS is a mock of OS interface
type MockOS struct {
ctrl *gomock.Controller
recorder *MockOSMockRecorder
}
// MockOSMockRecorder is the mock recorder for MockOS
type MockOSMockRecorder struct {
mock *MockOS
}
// NewMockOS creates a new mock instance
func NewMockOS(ctrl *gomock.Controller) *MockOS {
mock := &MockOS{ctrl: ctrl}
mock.recorder = &MockOSMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockOS) EXPECT() *MockOSMockRecorder {
return m.recorder
}
// Chown mocks base method
func (m *MockOS) Chown(arg0 string, arg1, arg2 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Chown", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// Chown indicates an expected call of Chown
func (mr *MockOSMockRecorder) Chown(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockOS)(nil).Chown), arg0, arg1, arg2)
}
// MkdirAll mocks base method
func (m *MockOS) MkdirAll(arg0 string, arg1 os.FileMode) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MkdirAll", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// MkdirAll indicates an expected call of MkdirAll
func (mr *MockOSMockRecorder) MkdirAll(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockOS)(nil).MkdirAll), arg0, arg1)
}
// OpenFile mocks base method
func (m *MockOS) OpenFile(arg0 string, arg1 int, arg2 os.FileMode) (os.File, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenFile", arg0, arg1, arg2)
ret0, _ := ret[0].(os.File)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenFile indicates an expected call of OpenFile
func (mr *MockOSMockRecorder) OpenFile(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenFile", reflect.TypeOf((*MockOS)(nil).OpenFile), arg0, arg1, arg2)
}
// Remove mocks base method
func (m *MockOS) Remove(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Remove", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Remove indicates an expected call of Remove
func (mr *MockOSMockRecorder) Remove(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockOS)(nil).Remove), arg0)
}
// Stat mocks base method
func (m *MockOS) Stat(arg0 string) (os0.FileInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stat", arg0)
ret0, _ := ret[0].(os0.FileInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Stat indicates an expected call of Stat
func (mr *MockOSMockRecorder) Stat(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockOS)(nil).Stat), arg0)
}
// Unsetenv mocks base method
func (m *MockOS) Unsetenv(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unsetenv", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Unsetenv indicates an expected call of Unsetenv
func (mr *MockOSMockRecorder) Unsetenv(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsetenv", reflect.TypeOf((*MockOS)(nil).Unsetenv), arg0)
}

39
internal/os/os.go Normal file
View File

@@ -0,0 +1,39 @@
package os
import nativeos "os"
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . OS
type OS interface {
OpenFile(name string, flag int, perm FileMode) (File, error)
MkdirAll(name string, perm FileMode) error
Remove(name string) error
Chown(name string, uid int, gid int) error
Unsetenv(key string) error
Stat(name string) (nativeos.FileInfo, error)
}
func New() OS {
return &os{}
}
type os struct{}
func (o *os) OpenFile(name string, flag int, perm FileMode) (File, error) {
return nativeos.OpenFile(name, flag, nativeos.FileMode(perm))
}
func (o *os) MkdirAll(name string, perm FileMode) error {
return nativeos.MkdirAll(name, nativeos.FileMode(perm))
}
func (o *os) Remove(name string) error {
return nativeos.Remove(name)
}
func (o *os) Chown(name string, uid, gid int) error {
return nativeos.Chown(name, uid, gid)
}
func (o *os) Unsetenv(key string) error {
return nativeos.Unsetenv(key)
}
func (o *os) Stat(name string) (nativeos.FileInfo, error) {
return nativeos.Stat(name)
}

View File

@@ -3,6 +3,8 @@ package params
import ( import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil"
"os"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
@@ -32,10 +34,19 @@ func (p *reader) GetCyberghostClientKey() (clientKey string, err error) {
} else if len(clientKey) > 0 { } else if len(clientKey) > 0 {
return clientKey, nil return clientKey, nil
} }
content, err := p.fileManager.ReadFile(string(constants.ClientKey)) const filepath = string(constants.ClientKey)
file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return "", err return "", err
} }
content, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return "", err
}
if err := file.Close(); err != nil {
return "", err
}
return extractClientKey(content) return extractClientKey(content)
} }
@@ -55,10 +66,19 @@ func extractClientKey(b []byte) (key string, err error) {
// GetCyberghostClientCertificate obtains the client certificate to use for openvpn from the // GetCyberghostClientCertificate obtains the client certificate to use for openvpn from the
// file at /gluetun/client.crt. // file at /gluetun/client.crt.
func (p *reader) GetCyberghostClientCertificate() (clientCertificate string, err error) { func (p *reader) GetCyberghostClientCertificate() (clientCertificate string, err error) {
content, err := p.fileManager.ReadFile(string(constants.ClientCertificate)) const filepath = string(constants.ClientCertificate)
file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return "", err return "", err
} }
content, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return "", err
}
if err := file.Close(); err != nil {
return "", err
}
return extractClientCertificate(content) return extractClientCertificate(content)
} }

View File

@@ -11,7 +11,7 @@ import (
// GetUser obtains the user to use to connect to the VPN servers. // GetUser obtains the user to use to connect to the VPN servers.
func (r *reader) GetUser() (s string, err error) { func (r *reader) GetUser() (s string, err error) {
defer func() { defer func() {
unsetenvErr := r.unsetEnv("USER") unsetenvErr := r.os.Unsetenv("USER")
if err == nil { if err == nil {
err = unsetenvErr err = unsetenvErr
} }
@@ -22,7 +22,7 @@ func (r *reader) GetUser() (s string, err error) {
// GetPassword obtains the password to use to connect to the VPN servers. // GetPassword obtains the password to use to connect to the VPN servers.
func (r *reader) GetPassword(required bool) (s string, err error) { func (r *reader) GetPassword(required bool) (s string, err error) {
defer func() { defer func() {
unsetenvErr := r.unsetEnv("PASSWORD") unsetenvErr := r.os.Unsetenv("PASSWORD")
if err == nil { if err == nil {
err = unsetenvErr err = unsetenvErr
} }

View File

@@ -2,11 +2,10 @@ package params
import ( import (
"net" "net"
"os"
"time" "time"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
libparams "github.com/qdm12/golibs/params" libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/golibs/verification" "github.com/qdm12/golibs/verification"
@@ -128,22 +127,20 @@ type Reader interface {
} }
type reader struct { type reader struct {
envParams libparams.EnvParams envParams libparams.EnvParams
logger logging.Logger logger logging.Logger
verifier verification.Verifier verifier verification.Verifier
unsetEnv func(key string) error os os.OS
fileManager files.FileManager
} }
// Newreader returns a paramsReadeer object to read parameters from // Newreader returns a paramsReadeer object to read parameters from
// environment variables. // environment variables.
func NewReader(logger logging.Logger, fileManager files.FileManager) Reader { func NewReader(logger logging.Logger, os os.OS) Reader {
return &reader{ return &reader{
envParams: libparams.NewEnvParams(), envParams: libparams.NewEnvParams(),
logger: logger, logger: logger,
verifier: verification.NewVerifier(), verifier: verification.NewVerifier(),
unsetEnv: os.Unsetenv, os: os,
fileManager: fileManager,
} }
} }

View File

@@ -36,7 +36,7 @@ func (r *reader) GetShadowSocksPort() (port uint16, err error) {
// SHADOWSOCKS_PASSWORD. // SHADOWSOCKS_PASSWORD.
func (r *reader) GetShadowSocksPassword() (password string, err error) { func (r *reader) GetShadowSocksPassword() (password string, err error) {
defer func() { defer func() {
unsetErr := r.unsetEnv("SHADOWSOCKS_PASSWORD") unsetErr := r.os.Unsetenv("SHADOWSOCKS_PASSWORD")
if err == nil { if err == nil {
err = unsetErr err = unsetErr
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -133,7 +133,7 @@ func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity in
} }
func (c *cyberghost) PortForward(ctx context.Context, client *http.Client, func (c *cyberghost) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for cyberghost") panic("port forwarding is not supported for cyberghost")
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -128,7 +128,7 @@ func (m *mullvad) BuildConf(connection models.OpenVPNConnection,
} }
func (m *mullvad) PortForward(ctx context.Context, client *http.Client, func (m *mullvad) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for mullvad") panic("port forwarding is not supported for mullvad")
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -142,7 +142,7 @@ func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
} }
func (n *nordvpn) PortForward(ctx context.Context, client *http.Client, func (n *nordvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for nordvpn") panic("port forwarding is not supported for nordvpn")
} }

View File

@@ -19,7 +19,7 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
gluetunLog "github.com/qdm12/gluetun/internal/logging" gluetunLog "github.com/qdm12/gluetun/internal/logging"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -183,7 +183,7 @@ func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity int, user
//nolint:gocognit //nolint:gocognit
func (p *pia) PortForward(ctx context.Context, client *http.Client, func (p *pia) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
if !p.activeServer.PortForward { if !p.activeServer.PortForward {
pfLogger.Error("The server %s does not support port forwarding", p.activeServer.Region) pfLogger.Error("The server %s does not support port forwarding", p.activeServer.Region)
@@ -203,7 +203,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
return return
} }
defer pfLogger.Warn("loop exited") defer pfLogger.Warn("loop exited")
data, err := readPIAPortForwardData(fileManager) data, err := readPIAPortForwardData(openFile)
if err != nil { if err != nil {
pfLogger.Error(err) pfLogger.Error(err)
} }
@@ -222,7 +222,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired { if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error { tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager) data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
return err return err
}) })
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -240,12 +240,9 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
return return
} }
filepath := syncState(data.Port) filepath := string(syncState(data.Port))
pfLogger.Info("Writing port to %s", filepath) pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile( if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(constants.AllReadWritePermissions),
); err != nil {
pfLogger.Error(err) pfLogger.Error(err)
} }
@@ -281,7 +278,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port oldPort := data.Port
for { for {
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager) data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
if err != nil { if err != nil {
pfLogger.Error(err) pfLogger.Error(err)
continue continue
@@ -298,10 +295,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
} }
filepath := syncState(data.Port) filepath := syncState(data.Port)
pfLogger.Info("Writing port to %s", filepath) pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile( if err := writePortForwardedToFile(openFile, string(filepath), data.Port); err != nil {
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(constants.AllReadWritePermissions),
); err != nil {
pfLogger.Error(err) pfLogger.Error(err)
} }
if err := bindPIAPort(ctx, client, gateway, data); err != nil { if err := bindPIAPort(ctx, client, gateway, data); err != nil {
@@ -365,8 +359,8 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
} }
func refreshPIAPortForwardData(ctx context.Context, client *http.Client, func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) { gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(ctx, fileManager, client) data.Token, err = fetchPIAToken(ctx, openFile, client)
if err != nil { if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err) return data, fmt.Errorf("cannot obtain token: %w", err)
} }
@@ -374,7 +368,7 @@ func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
if err != nil { if err != nil {
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err) return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
} }
if err := writePIAPortForwardData(fileManager, data); err != nil { if err := writePIAPortForwardData(openFile, data); err != nil {
return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err) return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err)
} }
return data, nil return data, nil
@@ -393,34 +387,39 @@ type piaPortForwardData struct {
Expiration time.Time `json:"expires_at"` Expiration time.Time `json:"expires_at"`
} }
func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) { func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
const filepath = string(constants.PIAPortForward) const filepath = string(constants.PIAPortForward)
exists, err := fileManager.FileExists(filepath) file, err := openFile(filepath, os.O_RDONLY, 0)
if err != nil { if os.IsNotExist(err) {
return data, err
} else if !exists {
return data, nil return data, nil
} else if err != nil {
return data, err
} }
b, err := fileManager.ReadFile(filepath)
decoder := json.NewDecoder(file)
err = decoder.Decode(&data)
if err != nil { if err != nil {
_ = file.Close()
return data, err return data, err
} }
if err := json.Unmarshal(b, &data); err != nil { return data, file.Close()
return data, err
}
return data, nil
} }
func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) { func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
b, err := json.Marshal(&data) const filepath = string(constants.PIAPortForward)
if err != nil { file, err := openFile(filepath,
return fmt.Errorf("cannot encode data: %w", err) os.O_CREATE|os.O_TRUNC|os.O_WRONLY,
} 0644)
err = fileManager.WriteToFile(string(constants.PIAPortForward), b)
if err != nil { if err != nil {
return err return err
} }
return nil encoder := json.NewEncoder(file)
err = encoder.Encode(data)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
} }
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) { func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
@@ -449,8 +448,9 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st
return payload, nil return payload, nil
} }
func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *http.Client) (token string, err error) { func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc,
username, password, err := getOpenvpnCredentials(fileManager) client *http.Client) (token string, err error) {
username, password, err := getOpenvpnCredentials(openFile)
if err != nil { if err != nil {
return "", fmt.Errorf("cannot get Openvpn credentials: %w", err) return "", fmt.Errorf("cannot get Openvpn credentials: %w", err)
} }
@@ -489,10 +489,19 @@ func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *h
return result.Token, nil return result.Token, nil
} }
func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) { func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf)) const filepath = string(constants.OpenVPNAuthConf)
file, err := openFile(filepath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err) return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err)
}
authData, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err)
}
if err := file.Close(); err != nil {
return "", "", err
} }
lines := strings.Split(string(authData), "\n") lines := strings.Split(string(authData), "\n")
const minLines = 2 const minLines = 2
@@ -586,3 +595,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
} }
return nil return nil
} }
func writePortForwardedToFile(openFile os.OpenFileFunc,
filepath string, port uint16) (err error) {
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -117,7 +117,7 @@ func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity int,
} }
func (s *privado) PortForward(ctx context.Context, client *http.Client, func (s *privado) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for privado") panic("port forwarding is not supported for privado")
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -18,7 +18,7 @@ type Provider interface {
BuildConf(connection models.OpenVPNConnection, verbosity int, username string, BuildConf(connection models.OpenVPNConnection, verbosity int, username string,
root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
PortForward(ctx context.Context, client *http.Client, PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) syncState func(port uint16) (pfFilepath models.Filepath))
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -150,7 +150,7 @@ func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
} }
func (p *purevpn) PortForward(ctx context.Context, client *http.Client, func (p *purevpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for purevpn") panic("port forwarding is not supported for purevpn")
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -138,7 +138,7 @@ func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity int
} }
func (s *surfshark) PortForward(ctx context.Context, client *http.Client, func (s *surfshark) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for surfshark") panic("port forwarding is not supported for surfshark")
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -119,7 +119,7 @@ func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
} }
func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client, func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for vyprvpn") panic("port forwarding is not supported for vyprvpn")
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files" "github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -132,7 +132,7 @@ func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity in
} }
func (w *windscribe) PortForward(ctx context.Context, client *http.Client, func (w *windscribe) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) { syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for windscribe") panic("port forwarding is not supported for windscribe")
} }

27
internal/publicip/fs.go Normal file
View File

@@ -0,0 +1,27 @@
package publicip
import "github.com/qdm12/gluetun/internal/os"
func persistPublicIP(openFile os.OpenFileFunc,
filepath string, content string, uid, gid int) error {
file, err := openFile(
filepath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE,
0644)
if err != nil {
return err
}
_, err = file.WriteString(content)
if err != nil {
_ = file.Close()
return err
}
if err := file.Chown(uid, gid); err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -8,8 +8,8 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
) )
@@ -27,9 +27,9 @@ type Looper interface {
type looper struct { type looper struct {
state state state state
// Objects // Objects
getter IPGetter getter IPGetter
logger logging.Logger logger logging.Logger
fileManager files.FileManager os os.OS
// Fixed settings // Fixed settings
uid int uid int
gid int gid int
@@ -45,8 +45,9 @@ type looper struct {
timeSince func(time.Time) time.Duration timeSince func(time.Time) time.Duration
} }
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, func NewLooper(client network.Client, logger logging.Logger,
settings settings.PublicIP, uid, gid int) Looper { settings settings.PublicIP, uid, gid int,
os os.OS) Looper {
return &looper{ return &looper{
state: state{ state: state{
status: constants.Stopped, status: constants.Stopped,
@@ -55,7 +56,7 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
// Objects // Objects
getter: NewIPGetter(client), getter: NewIPGetter(client),
logger: logger.WithPrefix("ip getter: "), logger: logger.WithPrefix("ip getter: "),
fileManager: fileManager, os: os,
uid: uid, uid: uid,
gid: gid, gid: gid,
start: make(chan struct{}), start: make(chan struct{}),
@@ -125,7 +126,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
close(errorCh) close(errorCh)
filepath := l.GetSettings().IPFilepath filepath := l.GetSettings().IPFilepath
l.logger.Info("Removing ip file %s", filepath) l.logger.Info("Removing ip file %s", filepath)
if err := l.fileManager.Remove(string(filepath)); err != nil { if err := l.os.Remove(string(filepath)); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
return return
@@ -142,12 +143,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
getCancel() getCancel()
l.state.setPublicIP(ip) l.state.setPublicIP(ip)
l.logger.Info("Public IP address is %s", ip) l.logger.Info("Public IP address is %s", ip)
const userReadWritePermissions = 0600 filepath := string(l.state.settings.IPFilepath)
err := l.fileManager.WriteLinesToFile( err := persistPublicIP(l.os.OpenFile, filepath, ip.String(), l.uid, l.gid)
string(l.state.settings.IPFilepath),
[]string{ip.String()},
files.Ownership(l.uid, l.gid),
files.Permissions(userReadWritePermissions))
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }

View File

@@ -1,10 +1,8 @@
package storage package storage
import ( import (
"io/ioutil"
"os"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -14,17 +12,13 @@ type Storage interface {
} }
type storage struct { type storage struct {
osStat func(name string) (os.FileInfo, error) os os.OS
readFile func(filename string) (data []byte, err error) logger logging.Logger
writeFile func(filename string, data []byte, perm os.FileMode) error
logger logging.Logger
} }
func New(logger logging.Logger) Storage { func New(logger logging.Logger, os os.OS) Storage {
return &storage{ return &storage{
osStat: os.Stat, os: os,
readFile: ioutil.ReadFile, logger: logger.WithPrefix("storage: "),
writeFile: ioutil.WriteFile,
logger: logger.WithPrefix("storage: "),
} }
} }

View File

@@ -3,10 +3,10 @@ package storage
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"reflect" "reflect"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
) )
const ( const (
@@ -29,14 +29,18 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
allServers models.AllServers, err error) { allServers models.AllServers, err error) {
// Eventually read file // Eventually read file
var serversOnFile models.AllServers var serversOnFile models.AllServers
_, err = s.osStat(jsonFilepath) file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0)
if err != nil && !os.IsNotExist(err) {
return allServers, err
}
if err == nil { if err == nil {
serversOnFile, err = s.readFromFile() var serversOnFile models.AllServers
if err != nil { decoder := json.NewDecoder(file)
if err := decoder.Decode(&serversOnFile); err != nil {
_ = file.Close()
return allServers, err return allServers, err
} }
} else if !os.IsNotExist(err) { return allServers, file.Close()
return allServers, err
} }
// Merge data from file and hardcoded // Merge data from file and hardcoded
@@ -51,24 +55,16 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
return allServers, s.FlushToFile(allServers) return allServers, s.FlushToFile(allServers)
} }
func (s *storage) readFromFile() (servers models.AllServers, err error) {
bytes, err := s.readFile(jsonFilepath)
if err != nil {
return servers, err
}
if err := json.Unmarshal(bytes, &servers); err != nil {
return servers, err
}
return servers, nil
}
func (s *storage) FlushToFile(servers models.AllServers) error { func (s *storage) FlushToFile(servers models.AllServers) error {
bytes, err := json.MarshalIndent(servers, "", " ") file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil { if err != nil {
return fmt.Errorf("cannot write to file: %w", err)
}
if err := s.writeFile(jsonFilepath, bytes, 0644); err != nil {
return err return err
} }
return nil encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
_ = file.Close()
return fmt.Errorf("cannot write to file: %w", err)
}
return file.Close()
} }