Maint: do not mock os functions

- Use filepaths with /tmp for tests instead
- Only mock functions where filepath can't be specified such as user.Lookup
This commit is contained in:
Quentin McGaw (desktop)
2021-07-23 16:06:19 +00:00
parent e94684aa39
commit 21f4cf7ab5
48 changed files with 226 additions and 243 deletions

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
nativeos "os" "os"
"os/signal" "os/signal"
"strconv" "strconv"
"strings" "strings"
@@ -33,8 +33,6 @@ import (
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
versionpkg "github.com/qdm12/gluetun/internal/version" versionpkg "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/user"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
"github.com/qdm12/goshutdown" "github.com/qdm12/goshutdown"
"github.com/qdm12/gosplash" "github.com/qdm12/gosplash"
@@ -61,21 +59,19 @@ func main() {
} }
ctx := context.Background() ctx := context.Background()
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, nativeos.Interrupt) ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
logger := logging.NewParent(logging.Settings{}) logger := logging.NewParent(logging.Settings{})
args := nativeos.Args args := os.Args
os := os.New()
osUser := user.New()
unix := unix.New() unix := unix.New()
cli := cli.New() cli := cli.New()
env := params.NewEnv() env := params.NewEnv()
errorCh := make(chan error) errorCh := make(chan error)
go func() { go func() {
errorCh <- _main(ctx, buildInfo, args, logger, env, os, osUser, unix, cli) errorCh <- _main(ctx, buildInfo, args, logger, env, unix, cli)
}() }()
select { select {
@@ -86,7 +82,7 @@ func main() {
stop() stop()
close(errorCh) close(errorCh)
if err == nil { // expected exit such as healthcheck if err == nil { // expected exit such as healthcheck
nativeos.Exit(0) os.Exit(0)
} }
logger.Error(err) logger.Error(err)
cancel() cancel()
@@ -104,7 +100,7 @@ func main() {
logger.Warn("Shutdown timed out") logger.Warn("Shutdown timed out")
} }
nativeos.Exit(1) os.Exit(1)
} }
var ( var (
@@ -113,18 +109,18 @@ var (
//nolint:gocognit,gocyclo //nolint:gocognit,gocyclo
func _main(ctx context.Context, buildInfo models.BuildInformation, func _main(ctx context.Context, buildInfo models.BuildInformation,
args []string, logger logging.ParentLogger, env params.Env, os os.OS, args []string, logger logging.ParentLogger, env params.Env,
osUser user.OSUser, unix unix.Unix, cli cli.CLI) error { unix unix.Unix, cli cli.CLI) error {
if len(args) > 1 { // cli operation if len(args) > 1 { // cli operation
switch args[1] { switch args[1] {
case "healthcheck": case "healthcheck":
return cli.HealthCheck(ctx, env, os, logger) return cli.HealthCheck(ctx, env, logger)
case "clientkey": case "clientkey":
return cli.ClientKey(args[2:], os.OpenFile) return cli.ClientKey(args[2:])
case "openvpnconfig": case "openvpnconfig":
return cli.OpenvpnConfig(os, logger) return cli.OpenvpnConfig(logger)
case "update": case "update":
return cli.Update(ctx, args[2:], os, logger) return cli.Update(ctx, args[2:], logger)
default: default:
return fmt.Errorf("%w: %s", errCommandUnknown, args[1]) return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
} }
@@ -133,19 +129,19 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
const clientTimeout = 15 * time.Second const clientTimeout = 15 * time.Second
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
// Create configurators // Create configurators
alpineConf := alpine.NewConfigurator(os.OpenFile, osUser) alpineConf := alpine.NewConfigurator()
ovpnConf := openvpn.NewConfigurator( ovpnConf := openvpn.NewConfigurator(
logger.NewChild(logging.Settings{Prefix: "openvpn configurator: "}), logger.NewChild(logging.Settings{Prefix: "openvpn configurator: "}),
os, unix) unix)
dnsCrypto := dnscrypto.New(httpClient, "", "") dnsCrypto := dnscrypto.New(httpClient, "", "")
const cacertsPath = "/etc/ssl/certs/ca-certificates.crt" const cacertsPath = "/etc/ssl/certs/ca-certificates.crt"
dnsConf := unbound.NewConfigurator(nil, os.OpenFile, dnsCrypto, dnsConf := unbound.NewConfigurator(nil, dnsCrypto,
"/etc/unbound", "/usr/sbin/unbound", cacertsPath) "/etc/unbound", "/usr/sbin/unbound", cacertsPath)
routingConf := routing.NewRouting( routingConf := routing.NewRouting(
logger.NewChild(logging.Settings{Prefix: "routing: "})) logger.NewChild(logging.Settings{Prefix: "routing: "}))
firewallConf := firewall.NewConfigurator( firewallConf := firewall.NewConfigurator(
logger.NewChild(logging.Settings{Prefix: "firewall: "}), logger.NewChild(logging.Settings{Prefix: "firewall: "}),
routingConf, os.OpenFile) routingConf)
announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z") announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z")
if err != nil { if err != nil {
@@ -179,7 +175,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
var allSettings configuration.Settings var allSettings configuration.Settings
err = allSettings.Read(env, os, err = allSettings.Read(env,
logger.NewChild(logging.Settings{Prefix: "configuration: "})) logger.NewChild(logging.Settings{Prefix: "configuration: "}))
if err != nil { if err != nil {
return err return err
@@ -196,7 +192,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
// TODO run this in a loop or in openvpn to reload from file without restarting // TODO run this in a loop or in openvpn to reload from file without restarting
storage := storage.New( storage := storage.New(
logger.NewChild(logging.Settings{Prefix: "storage: "}), logger.NewChild(logging.Settings{Prefix: "storage: "}),
os, constants.ServersData) constants.ServersData)
allServers, err := storage.SyncServers(constants.GetAllServers()) allServers, err := storage.SyncServers(constants.GetAllServers())
if err != nil { if err != nil {
return err return err
@@ -314,7 +310,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings) otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings)
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh) ovpnConf, firewallConf, routingConf, logger, httpClient, tunnelReadyCh)
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
// wait for restartOpenvpn // wait for restartOpenvpn
@@ -331,7 +327,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
unboundLogger := logger.NewChild(logging.Settings{Prefix: "dns over tls: "}) unboundLogger := logger.NewChild(logging.Settings{Prefix: "dns over tls: "})
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient, unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient,
unboundLogger, os.OpenFile) unboundLogger)
dnsHandler, dnsCtx, dnsDone := goshutdown.NewGoRoutineHandler( dnsHandler, dnsCtx, dnsDone := goshutdown.NewGoRoutineHandler(
"unbound", defaultGoRoutineSettings) "unbound", defaultGoRoutineSettings)
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
@@ -340,7 +336,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
publicIPLooper := publicip.NewLooper(httpClient, publicIPLooper := publicip.NewLooper(httpClient,
logger.NewChild(logging.Settings{Prefix: "ip getter: "}), logger.NewChild(logging.Settings{Prefix: "ip getter: "}),
allSettings.PublicIP, puid, pgid, os) allSettings.PublicIP, puid, pgid)
pubIPHandler, pubIPCtx, pubIPDone := goshutdown.NewGoRoutineHandler( pubIPHandler, pubIPCtx, pubIPDone := goshutdown.NewGoRoutineHandler(
"public IP", defaultGoRoutineSettings) "public IP", defaultGoRoutineSettings)
go publicIPLooper.Run(pubIPCtx, pubIPDone) go publicIPLooper.Run(pubIPCtx, pubIPDone)

2
go.mod
View File

@@ -5,7 +5,7 @@ go 1.16
require ( require (
github.com/fatih/color v1.12.0 github.com/fatih/color v1.12.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/qdm12/dns v1.9.0 github.com/qdm12/dns v1.10.0
github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99
github.com/qdm12/goshutdown v0.1.0 github.com/qdm12/goshutdown v0.1.0
github.com/qdm12/gosplash v0.1.0 github.com/qdm12/gosplash v0.1.0

4
go.sum
View File

@@ -63,8 +63,8 @@ github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMg
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/qdm12/dns v1.9.0 h1:p4g/BfbpQ+gJRpQdklDAnybkjds+OuenF0wEGoZ8/AI= github.com/qdm12/dns v1.10.0 h1:WX5QQ5+2h34xfhfxJTmvyURbs9XE4qNrEGtyNeq38Bw=
github.com/qdm12/dns v1.9.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg= github.com/qdm12/dns v1.10.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg=
github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg= github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg=
github.com/qdm12/golibs v0.0.0-20210716185557-66793f4ddd80/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg= github.com/qdm12/golibs v0.0.0-20210716185557-66793f4ddd80/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg=
github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 h1:2OKHAR0SK8BtTtWCRNoSn58eh+iVDA3Cwq4i2CnD3i4= github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 h1:2OKHAR0SK8BtTtWCRNoSn58eh+iVDA3Cwq4i2CnD3i4=

View File

@@ -3,9 +3,7 @@ package alpine
import ( import (
"context" "context"
"os/user"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/user"
) )
type Configurator interface { type Configurator interface {
@@ -14,13 +12,17 @@ type Configurator interface {
} }
type configurator struct { type configurator struct {
openFile os.OpenFileFunc alpineReleasePath string
osUser user.OSUser passwdPath string
lookupID func(uid string) (*user.User, error)
lookup func(username string) (*user.User, error)
} }
func NewConfigurator(openFile os.OpenFileFunc, osUser user.OSUser) Configurator { func NewConfigurator() Configurator {
return &configurator{ return &configurator{
openFile: openFile, alpineReleasePath: "/etc/alpine-release",
osUser: osUser, passwdPath: "/etc/passwd",
lookupID: user.LookupId,
lookup: user.Lookup,
} }
} }

View File

@@ -15,7 +15,7 @@ var (
// CreateUser creates a user in Alpine with the given UID. // CreateUser creates a user in Alpine with the given UID.
func (c *configurator) CreateUser(username string, uid int) (createdUsername string, err error) { func (c *configurator) CreateUser(username string, uid int) (createdUsername string, err error) {
UIDStr := strconv.Itoa(uid) UIDStr := strconv.Itoa(uid)
u, err := c.osUser.LookupID(UIDStr) u, err := c.lookupID(UIDStr)
_, unknownUID := err.(user.UnknownUserIdError) _, unknownUID := err.(user.UnknownUserIdError)
if err != nil && !unknownUID { if err != nil && !unknownUID {
return "", err return "", err
@@ -28,7 +28,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
return u.Username, nil return u.Username, nil
} }
u, err = c.osUser.Lookup(username) u, err = c.lookup(username)
_, unknownUsername := err.(user.UnknownUserError) _, unknownUsername := err.(user.UnknownUserError)
if err != nil && !unknownUsername { if err != nil && !unknownUsername {
return "", err return "", err
@@ -39,7 +39,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
ErrUserAlreadyExists, username, u.Uid, uid) ErrUserAlreadyExists, username, u.Uid, uid)
} }
file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644) file, err := os.OpenFile(c.passwdPath, os.O_APPEND|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -8,7 +8,7 @@ import (
) )
func (c *configurator) Version(ctx context.Context) (version string, err error) { func (c *configurator) Version(ctx context.Context) (version string, err error) {
file, err := c.openFile("/etc/alpine-release", os.O_RDONLY, 0) file, err := os.OpenFile(c.alpineReleasePath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -5,19 +5,22 @@ import (
"context" "context"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
) )
type CLI interface { type CLI interface {
ClientKey(args []string, openFile os.OpenFileFunc) error ClientKey(args []string) error
HealthCheck(ctx context.Context, env params.Env, os os.OS, logger logging.Logger) error HealthCheck(ctx context.Context, env params.Env, logger logging.Logger) error
OpenvpnConfig(os os.OS, logger logging.Logger) error OpenvpnConfig(logger logging.Logger) error
Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error Update(ctx context.Context, args []string, logger logging.Logger) error
} }
type cli struct{} type cli struct {
repoServersPath string
}
func New() CLI { func New() CLI {
return &cli{} return &cli{
repoServersPath: "./internal/constants/servers.json",
}
} }

View File

@@ -4,19 +4,19 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"os"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/os"
) )
func (c *cli) ClientKey(args []string, openFile os.OpenFileFunc) error { func (c *cli) ClientKey(args []string) error {
flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError) flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError)
filepath := flagSet.String("path", constants.ClientKey, "file path to the client.key file") filepath := flagSet.String("path", constants.ClientKey, "file path to the client.key file")
if err := flagSet.Parse(args); err != nil { if err := flagSet.Parse(args); err != nil {
return err return err
} }
file, err := openFile(*filepath, os.O_RDONLY, 0) file, err := os.OpenFile(*filepath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,15 +9,14 @@ import (
"github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/healthcheck" "github.com/qdm12/gluetun/internal/healthcheck"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
) )
func (c *cli) HealthCheck(ctx context.Context, env params.Env, func (c *cli) HealthCheck(ctx context.Context, env params.Env,
os os.OS, logger logging.Logger) error { logger logging.Logger) error {
// Extract the health server port from the configuration. // Extract the health server port from the configuration.
config := configuration.Health{} config := configuration.Health{}
err := config.Read(env, os, logger) err := config.Read(env, logger)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -10,17 +10,16 @@ import (
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
) )
func (c *cli) OpenvpnConfig(os os.OS, logger logging.Logger) error { func (c *cli) OpenvpnConfig(logger logging.Logger) error {
var allSettings configuration.Settings var allSettings configuration.Settings
err := allSettings.Read(params.NewEnv(), os, logger) err := allSettings.Read(params.NewEnv(), logger)
if err != nil { if err != nil {
return err return err
} }
allServers, err := storage.New(logger, os, constants.ServersData). allServers, err := storage.New(logger, constants.ServersData).
SyncServers(constants.GetAllServers()) SyncServers(constants.GetAllServers())
if err != nil { if err != nil {
return err return err

View File

@@ -7,7 +7,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"net/http" "net/http"
nativeos "os" "os"
"time" "time"
"github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/configuration"
@@ -16,7 +16,6 @@ import (
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
var ( var (
@@ -26,7 +25,7 @@ var (
ErrWriteToFile = errors.New("cannot write updated information to file") ErrWriteToFile = errors.New("cannot write updated information to file")
) )
func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error { func (c *cli) Update(ctx context.Context, args []string, logger logging.Logger) error {
options := configuration.Updater{CLI: true} options := configuration.Updater{CLI: true}
var endUserMode, maintainerMode, updateAll bool var endUserMode, maintainerMode, updateAll bool
flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet := flag.NewFlagSet("update", flag.ExitOnError)
@@ -65,7 +64,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
const clientTimeout = 10 * time.Second const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
storage := storage.New(logger, os, constants.ServersData) storage := storage.New(logger, constants.ServersData)
currentServers, err := storage.SyncServers(constants.GetAllServers()) currentServers, err := storage.SyncServers(constants.GetAllServers())
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrSyncServers, err) return fmt.Errorf("%w: %s", ErrSyncServers, err)
@@ -83,7 +82,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
} }
if maintainerMode { if maintainerMode {
if err := writeToEmbeddedJSON(os, allServers); err != nil { if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
return fmt.Errorf("%w: %s", ErrWriteToFile, err) return fmt.Errorf("%w: %s", ErrWriteToFile, err)
} }
} }
@@ -91,10 +90,11 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
return nil return nil
} }
func writeToEmbeddedJSON(os os.OS, allServers models.AllServers) error { func writeToEmbeddedJSON(repoServersPath string,
allServers models.AllServers) error {
const perms = 0600 const perms = 0600
f, err := os.OpenFile("./internal/constants/servers.json", f, err := os.OpenFile(repoServersPath,
nativeos.O_TRUNC|nativeos.O_WRONLY|nativeos.O_CREATE, perms) os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -5,7 +5,6 @@ import (
"strings" "strings"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
) )
@@ -33,8 +32,8 @@ func (settings *Health) lines() (lines []string) {
} }
// Read is to be used for the healthcheck query mode. // Read is to be used for the healthcheck query mode.
func (settings *Health) Read(env params.Env, os os.OS, logger logging.Logger) (err error) { func (settings *Health) Read(env params.Env, logger logging.Logger) (err error) {
reader := newReader(env, os, logger) reader := newReader(env, logger)
return settings.read(reader) return settings.read(reader)
} }

View File

@@ -8,7 +8,6 @@ import (
"strings" "strings"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
"github.com/qdm12/golibs/verification" "github.com/qdm12/golibs/verification"
) )
@@ -17,15 +16,13 @@ type reader struct {
env params.Env env params.Env
logger logging.Logger logger logging.Logger
regex verification.Regex regex verification.Regex
os os.OS
} }
func newReader(env params.Env, os os.OS, logger logging.Logger) reader { func newReader(env params.Env, logger logging.Logger) reader {
return reader{ return reader{
env: env, env: env,
logger: logger, logger: logger,
regex: verification.NewRegex(), regex: verification.NewRegex(),
os: os,
} }
} }

View File

@@ -4,9 +4,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
"strings" "strings"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
) )
@@ -48,7 +48,7 @@ func (r *reader) getFromEnvOrSecretFile(envKey string, compulsory bool, retroKey
ErrGetSecretFilepath, secretFilepathEnvKey, err) ErrGetSecretFilepath, secretFilepathEnvKey, err)
} }
file, fileErr := r.os.OpenFile(filepath, os.O_RDONLY, 0) file, fileErr := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(fileErr) { if os.IsNotExist(fileErr) {
if compulsory { if compulsory {
return "", envErr return "", envErr
@@ -85,7 +85,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
return b, fmt.Errorf("environment variable %s: %w: %s", key, ErrGetSecretFilepath, err) return b, fmt.Errorf("environment variable %s: %w: %s", key, ErrGetSecretFilepath, err)
} }
b, err = readFromFile(r.os.OpenFile, secretFilepath) b, err = readFromFile(secretFilepath)
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
return b, fmt.Errorf("%w: %s", ErrReadSecretFile, err) return b, fmt.Errorf("%w: %s", ErrReadSecretFile, err)
} else if err == nil { } else if err == nil {
@@ -93,7 +93,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
} }
// Secret file does not exist, try the non secret file // Secret file does not exist, try the non secret file
b, err = readFromFile(r.os.OpenFile, filepath) b, err = readFromFile(filepath)
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("%w: %s", ErrReadSecretFile, err) return nil, fmt.Errorf("%w: %s", ErrReadSecretFile, err)
} else if err == nil { } else if err == nil {
@@ -102,8 +102,8 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
return nil, fmt.Errorf("%w: %s and %s", ErrFilesDoNotExist, secretFilepath, filepath) return nil, fmt.Errorf("%w: %s and %s", ErrFilesDoNotExist, secretFilepath, filepath)
} }
func readFromFile(openFile os.OpenFileFunc, filepath string) (b []byte, err error) { func readFromFile(filepath string) (b []byte, err error) {
file, err := openFile(filepath, os.O_RDONLY, 0) file, err := os.Open(filepath)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -6,7 +6,6 @@ import (
"strings" "strings"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params" "github.com/qdm12/golibs/params"
) )
@@ -62,8 +61,8 @@ var (
// Read obtains all configuration options for the program and returns an error as soon // Read obtains all configuration options for the program and returns an error as soon
// as an error is encountered reading them. // as an error is encountered reading them.
func (settings *Settings) Read(env params.Env, os os.OS, logger logging.Logger) (err error) { func (settings *Settings) Read(env params.Env, logger logging.Logger) (err error) {
r := newReader(env, os, logger) r := newReader(env, logger)
settings.VersionInformation, err = r.env.OnOff("VERSION_INFORMATION", params.Default("on")) settings.VersionInformation, err = r.env.OnOff("VERSION_INFORMATION", params.Default("on"))
if err != nil { if err != nil {

View File

@@ -16,7 +16,6 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
type Looper interface { type Looper interface {
@@ -33,6 +32,7 @@ type Looper interface {
type looper struct { type looper struct {
state *state state *state
conf unbound.Configurator conf unbound.Configurator
resolvConf string
blockBuilder blacklist.Builder blockBuilder blacklist.Builder
client *http.Client client *http.Client
logger logging.Logger logger logging.Logger
@@ -45,13 +45,12 @@ type looper struct {
backoffTime time.Duration backoffTime time.Duration
timeNow func() time.Time timeNow func() time.Time
timeSince func(time.Time) time.Duration timeSince func(time.Time) time.Duration
openFile os.OpenFileFunc
} }
const defaultBackoffTime = 10 * time.Second const defaultBackoffTime = 10 * time.Second
func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *http.Client, func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *http.Client,
logger logging.Logger, openFile os.OpenFileFunc) Looper { logger logging.Logger) Looper {
start := make(chan struct{}) start := make(chan struct{})
running := make(chan models.LoopStatus) running := make(chan models.LoopStatus)
stop := make(chan struct{}) stop := make(chan struct{})
@@ -63,6 +62,7 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
return &looper{ return &looper{
state: state, state: state,
conf: conf, conf: conf,
resolvConf: "/etc/resolv.conf",
blockBuilder: blacklist.NewBuilder(client), blockBuilder: blacklist.NewBuilder(client),
client: client, client: client,
logger: logger, logger: logger,
@@ -75,7 +75,6 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
backoffTime: defaultBackoffTime, backoffTime: defaultBackoffTime,
timeNow: time.Now, timeNow: time.Now,
timeSince: time.Since, timeSince: time.Since,
openFile: openFile,
} }
} }
@@ -227,8 +226,8 @@ func (l *looper) setupUnbound(ctx context.Context) (
// use Unbound // use Unbound
nameserver.UseDNSInternally(net.IP{127, 0, 0, 1}) nameserver.UseDNSInternally(net.IP{127, 0, 0, 1})
err = nameserver.UseDNSSystemWide(l.openFile, err = nameserver.UseDNSSystemWide(l.resolvConf, net.IP{127, 0, 0, 1},
net.IP{127, 0, 0, 1}, settings.KeepNameserver) settings.KeepNameserver)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
@@ -256,8 +255,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
l.logger.Info("using plaintext DNS at address %s", targetIP) l.logger.Info("using plaintext DNS at address %s", targetIP)
} }
nameserver.UseDNSInternally(targetIP) nameserver.UseDNSInternally(targetIP)
if err := nameserver.UseDNSSystemWide(l.openFile, err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver)
targetIP, settings.KeepNameserver); err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
return return
@@ -271,7 +270,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
l.logger.Info("using plaintext DNS at address " + targetIP.String()) l.logger.Info("using plaintext DNS at address " + targetIP.String())
} }
nameserver.UseDNSInternally(targetIP) nameserver.UseDNSInternally(targetIP)
if err := nameserver.UseDNSSystemWide(l.openFile, targetIP, settings.KeepNameserver); err != nil { err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver)
if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
} }

View File

@@ -136,7 +136,7 @@ func (c *configurator) enable(ctx context.Context) (err error) {
} }
} }
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil { if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("%w: %s", ErrUserPostRules, err) return fmt.Errorf("%w: %s", ErrUserPostRules, err)
} }

View File

@@ -11,7 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
// Configurator allows to change firewall rules and modify network routes. // Configurator allows to change firewall rules and modify network routes.
@@ -32,7 +31,6 @@ type configurator struct { //nolint:maligned
commander command.Commander commander command.Commander
logger logging.Logger logger logging.Logger
routing routing.Routing routing routing.Routing
openFile os.OpenFileFunc // for custom iptables rules
iptablesMutex sync.Mutex iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex ip6tablesMutex sync.Mutex
debug bool debug bool
@@ -43,7 +41,8 @@ type configurator struct { //nolint:maligned
networkInfoMutex sync.Mutex networkInfoMutex sync.Mutex
// Fixed state // Fixed state
ip6Tables bool ip6Tables bool
customRulesPath string
// State // State
enabled bool enabled bool
@@ -54,15 +53,15 @@ type configurator struct { //nolint:maligned
} }
// NewConfigurator creates a new Configurator instance. // NewConfigurator creates a new Configurator instance.
func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator { func NewConfigurator(logger logging.Logger, routing routing.Routing) Configurator {
commander := command.NewCommander() commander := command.NewCommander()
return &configurator{ return &configurator{
commander: commander, commander: commander,
logger: logger, logger: logger,
routing: routing, routing: routing,
openFile: openFile,
allowedInputPorts: make(map[uint16]string), allowedInputPorts: make(map[uint16]string),
ip6Tables: ip6tablesSupported(context.Background(), commander), ip6Tables: ip6tablesSupported(context.Background(), commander),
customRulesPath: "/iptables/post-rules.txt",
} }
} }

View File

@@ -196,7 +196,7 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port
} }
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := c.openFile(filepath, os.O_RDONLY, 0) file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
} else if err != nil { } else if err != nil {

View File

@@ -10,14 +10,14 @@ import (
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions. // WriteAuthFile writes the OpenVPN auth file to disk with the right permissions.
func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) error { func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) error {
file, err := c.os.OpenFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0) file, err := os.Open(c.authFilePath)
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
return err return err
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_WRONLY|os.O_CREATE, 0400) file, err = os.OpenFile(c.authFilePath, os.O_WRONLY|os.O_CREATE, 0400)
if err != nil { if err != nil {
return err return err
} }
@@ -49,7 +49,7 @@ func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) erro
} }
c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf) c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf)
file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_TRUNC|os.O_WRONLY, 0400) file, err = os.OpenFile(c.authFilePath, os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"os"
"strconv" "strconv"
"strings" "strings"
@@ -12,14 +13,13 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/os"
) )
var errProcessCustomConfig = errors.New("cannot process custom config") var errProcessCustomConfig = errors.New("cannot process custom config")
func (l *looper) processCustomConfig(settings configuration.OpenVPN) ( func (l *looper) processCustomConfig(settings configuration.OpenVPN) (
lines []string, connection models.OpenVPNConnection, err error) { lines []string, connection models.OpenVPNConnection, err error) {
lines, err = readCustomConfigLines(settings.Config, l.openFile) lines, err = readCustomConfigLines(settings.Config)
if err != nil { if err != nil {
return nil, connection, fmt.Errorf("%w: %s", errProcessCustomConfig, err) return nil, connection, fmt.Errorf("%w: %s", errProcessCustomConfig, err)
} }
@@ -35,9 +35,9 @@ func (l *looper) processCustomConfig(settings configuration.OpenVPN) (
return lines, connection, nil return lines, connection, nil
} }
func readCustomConfigLines(filepath string, openFile os.OpenFileFunc) ( func readCustomConfigLines(filepath string) (
lines []string, err error) { lines []string, err error) {
file, err := openFile(filepath, os.O_RDONLY, 0) file, err := os.Open(filepath)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"time" "time"
@@ -14,7 +15,6 @@ import (
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
type Looper interface { type Looper interface {
@@ -34,9 +34,10 @@ type Looper interface {
type looper struct { type looper struct {
state *state state *state
// Fixed parameters // Fixed parameters
username string username string
puid int puid int
pgid int pgid int
targetConfPath string
// Configurators // Configurators
conf Configurator conf Configurator
fw firewall.Configurator fw firewall.Configurator
@@ -44,7 +45,6 @@ type looper struct {
// Other objects // Other objects
logger, pfLogger logging.Logger logger, pfLogger logging.Logger
client *http.Client client *http.Client
openFile os.OpenFileFunc
tunnelReady chan<- struct{} tunnelReady chan<- struct{}
// Internal channels and values // Internal channels and values
stop <-chan struct{} stop <-chan struct{}
@@ -64,7 +64,7 @@ const (
func NewLooper(settings configuration.OpenVPN, func NewLooper(settings configuration.OpenVPN,
username string, puid, pgid int, allServers models.AllServers, username string, puid, pgid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, routing routing.Routing, conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.ParentLogger, client *http.Client, openFile os.OpenFileFunc, logger logging.ParentLogger, client *http.Client,
tunnelReady chan<- struct{}) Looper { tunnelReady chan<- struct{}) Looper {
start := make(chan struct{}) start := make(chan struct{})
running := make(chan models.LoopStatus) running := make(chan models.LoopStatus)
@@ -79,13 +79,13 @@ func NewLooper(settings configuration.OpenVPN,
username: username, username: username,
puid: puid, puid: puid,
pgid: pgid, pgid: pgid,
targetConfPath: constants.OpenVPNConf,
conf: conf, conf: conf,
fw: fw, fw: fw,
routing: routing, routing: routing,
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}), logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}),
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}), pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}),
client: client, client: client,
openFile: openFile,
tunnelReady: tunnelReady, tunnelReady: tunnelReady,
start: start, start: start,
running: running, running: running,
@@ -145,7 +145,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
} }
} }
if err := writeOpenvpnConf(lines, l.openFile); err != nil { if err := l.writeOpenvpnConf(lines); err != nil {
l.signalOrSetStatus(constants.Crashed) l.signalOrSetStatus(constants.Crashed)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
continue continue
@@ -279,13 +279,12 @@ func (l *looper) portForward(ctx context.Context,
defer l.state.settingsMu.RUnlock() defer l.state.settingsMu.RUnlock()
return settings.Provider.PortForwarding.Filepath return settings.Provider.PortForwarding.Filepath
} }
providerConf.PortForward(ctx, providerConf.PortForward(ctx, client, l.pfLogger,
client, l.openFile, l.pfLogger,
gateway, l.fw, syncState) gateway, l.fw, syncState)
} }
func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error { func (l *looper) writeOpenvpnConf(lines []string) error {
file, err := openFile(constants.OpenVPNConf, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) file, err := os.OpenFile(l.targetConfPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -5,10 +5,10 @@ package openvpn
import ( import (
"context" "context"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/unix" "github.com/qdm12/gluetun/internal/unix"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
type Configurator interface { type Configurator interface {
@@ -22,17 +22,19 @@ type Configurator interface {
} }
type configurator struct { type configurator struct {
logger logging.Logger logger logging.Logger
commander command.Commander commander command.Commander
os os.OS unix unix.Unix
unix unix.Unix authFilePath string
tunDevPath string
} }
func NewConfigurator(logger logging.Logger, os os.OS, unix unix.Unix) Configurator { func NewConfigurator(logger logging.Logger, unix unix.Unix) Configurator {
return &configurator{ return &configurator{
logger: logger, logger: logger,
commander: command.NewCommander(), commander: command.NewCommander(),
os: os, unix: unix,
unix: unix, authFilePath: constants.OpenVPNAuthConf,
tunDevPath: constants.TunnelDevice,
} }
} }

View File

@@ -3,15 +3,15 @@ package openvpn
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/unix" "github.com/qdm12/gluetun/internal/unix"
) )
// CheckTUN checks the tunnel device is present and accessible. // CheckTUN checks the tunnel device is present and accessible.
func (c *configurator) CheckTUN() error { func (c *configurator) CheckTUN() error {
c.logger.Info("checking for device %s", constants.TunnelDevice) c.logger.Info("checking for device " + c.tunDevPath)
f, err := c.os.OpenFile(constants.TunnelDevice, os.O_RDWR, 0) f, err := os.OpenFile(c.tunDevPath, os.O_RDWR, 0)
if err != nil { if err != nil {
return fmt.Errorf("TUN device is not available: %w", err) return fmt.Errorf("TUN device is not available: %w", err)
} }
@@ -22,8 +22,10 @@ func (c *configurator) CheckTUN() error {
} }
func (c *configurator) CreateTUN() error { func (c *configurator) CreateTUN() error {
c.logger.Info("creating %s", constants.TunnelDevice) c.logger.Info("creating " + c.tunDevPath)
if err := c.os.MkdirAll("/dev/net", 0751); err != nil { //nolint:gomnd
parentDir := filepath.Dir(c.tunDevPath)
if err := os.MkdirAll(parentDir, 0751); err != nil { //nolint:gomnd
return err return err
} }
@@ -32,17 +34,13 @@ func (c *configurator) CreateTUN() error {
minor = 200 minor = 200
) )
dev := c.unix.Mkdev(major, minor) dev := c.unix.Mkdev(major, minor)
if err := c.unix.Mknod(constants.TunnelDevice, unix.S_IFCHR, int(dev)); err != nil { if err := c.unix.Mknod(c.tunDevPath, unix.S_IFCHR, int(dev)); err != nil {
return err return err
} }
file, err := c.os.OpenFile(constants.TunnelDevice, os.O_WRONLY, 0666) //nolint:gomnd
if err != nil {
return err
}
const readWriteAllPerms os.FileMode = 0666 const readWriteAllPerms os.FileMode = 0666
if err := file.Chmod(readWriteAllPerms); err != nil { file, err := os.OpenFile(c.tunDevPath, os.O_WRONLY, readWriteAllPerms)
_ = file.Close() if err != nil {
return err return err
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (c *Cyberghost) PortForward(ctx context.Context, client *http.Client, func (c *Cyberghost) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Cyberghost") panic("port forwarding is not supported for Cyberghost")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (f *Fastestvpn) PortForward(ctx context.Context, client *http.Client, func (f *Fastestvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for FastestVPN") panic("port forwarding is not supported for FastestVPN")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (f *HideMyAss) PortForward(ctx context.Context, client *http.Client, func (f *HideMyAss) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for HideMyAss") panic("port forwarding is not supported for HideMyAss")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (i *Ipvanish) PortForward(ctx context.Context, client *http.Client, func (i *Ipvanish) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Ipvanish") panic("port forwarding is not supported for Ipvanish")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (i *Ivpn) PortForward(ctx context.Context, client *http.Client, func (i *Ivpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Ivpn") panic("port forwarding is not supported for Ivpn")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (m *Mullvad) PortForward(ctx context.Context, client *http.Client, func (m *Mullvad) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding logic is not needed for Mullvad") panic("port forwarding logic is not needed for Mullvad")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (n *Nordvpn) PortForward(ctx context.Context, client *http.Client, func (n *Nordvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for NordVPN") panic("port forwarding is not supported for NordVPN")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (p *Privado) PortForward(ctx context.Context, client *http.Client, func (p *Privado) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Privado") panic("port forwarding is not supported for Privado")
} }

View File

@@ -10,6 +10,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -18,7 +19,6 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/format" "github.com/qdm12/gluetun/internal/format"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
var ( var (
@@ -28,7 +28,7 @@ var (
// PortForward obtains a VPN server side port forwarded from PIA. // PortForward obtains a VPN server side port forwarded from PIA.
//nolint:gocognit //nolint:gocognit
func (p *PIA) PortForward(ctx context.Context, client *http.Client, func (p *PIA) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
commonName := p.activeServer.ServerName commonName := p.activeServer.ServerName
if !p.activeServer.PortForward { if !p.activeServer.PortForward {
@@ -47,7 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
return return
} }
data, err := readPIAPortForwardData(openFile) data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
@@ -67,7 +67,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired { if !dataFound || expired {
tryUntilSuccessful(ctx, logger, func() error { tryUntilSuccessful(ctx, logger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile) data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
return err return err
}) })
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -91,7 +92,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
filepath := syncState(data.Port) filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath) logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error(err) logger.Error(err)
} }
@@ -128,7 +129,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
data.Expiration.Format(time.RFC1123) + ", getting another one") data.Expiration.Format(time.RFC1123) + ", getting another one")
oldPort := data.Port oldPort := data.Port
for { for {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile) data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
continue continue
@@ -146,7 +148,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
} }
filepath := syncState(data.Port) filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath) logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil { if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error("Cannot write port forward data to file: " + err.Error()) logger.Error("Cannot write port forward data to file: " + err.Error())
} }
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
@@ -168,8 +170,8 @@ var (
) )
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client, func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) { gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, openFile, client) data.Token, err = fetchToken(ctx, client, authFilePath)
if err != nil { if err != nil {
return data, fmt.Errorf("%w: %s", ErrFetchToken, err) return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
} }
@@ -179,7 +181,7 @@ func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *htt
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err) return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
} }
if err := writePIAPortForwardData(openFile, data); err != nil { if err := writePIAPortForwardData(portForwardPath, data); err != nil {
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err) return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
} }
@@ -199,8 +201,8 @@ type piaPortForwardData struct {
Expiration time.Time `json:"expires_at"` Expiration time.Time `json:"expires_at"`
} }
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) { func readPIAPortForwardData(portForwardPath string) (data piaPortForwardData, err error) {
file, err := openFile(constants.PIAPortForward, os.O_RDONLY, 0) file, err := os.Open(portForwardPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return data, nil return data, nil
} else if err != nil { } else if err != nil {
@@ -216,8 +218,8 @@ func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData,
return data, file.Close() return data, file.Close()
} }
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) { func writePIAPortForwardData(portForwardPath string, data piaPortForwardData) (err error) {
file, err := openFile(constants.PIAPortForward, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) file, err := os.OpenFile(portForwardPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return err return err
} }
@@ -269,9 +271,9 @@ var (
errEmptyToken = errors.New("token received is empty") errEmptyToken = errors.New("token received is empty")
) )
func fetchToken(ctx context.Context, openFile os.OpenFileFunc, func fetchToken(ctx context.Context, client *http.Client,
client *http.Client) (token string, err error) { authFilePath string) (token string, err error) {
username, password, err := getOpenvpnCredentials(openFile) username, password, err := getOpenvpnCredentials(authFilePath)
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s", errGetCredentials, err) return "", fmt.Errorf("%w: %s", errGetCredentials, err)
} }
@@ -321,8 +323,9 @@ var (
errAuthFileMalformed = errors.New("authentication file is malformed") errAuthFileMalformed = errors.New("authentication file is malformed")
) )
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) { func getOpenvpnCredentials(authFilePath string) (
file, err := openFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0) username, password string, err error) {
file, err := os.Open(authFilePath)
if err != nil { if err != nil {
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err) return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
} }
@@ -460,9 +463,8 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
return nil return nil
} }
func writePortForwardedToFile(openFile os.OpenFileFunc, func writePortForwardedToFile(filepath string, port uint16) (err error) {
filepath string, port uint16) (err error) { file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -4,6 +4,7 @@ import (
"math/rand" "math/rand"
"time" "time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
@@ -12,12 +13,18 @@ type PIA struct {
randSource rand.Source randSource rand.Source
timeNow func() time.Time timeNow func() time.Time
activeServer models.PIAServer activeServer models.PIAServer
// Port forwarding
portForwardPath string
authFilePath string
} }
func New(servers []models.PIAServer, randSource rand.Source, timeNow func() time.Time) *PIA { func New(servers []models.PIAServer, randSource rand.Source,
timeNow func() time.Time) *PIA {
return &PIA{ return &PIA{
servers: servers, servers: servers,
timeNow: timeNow, timeNow: timeNow,
randSource: randSource, randSource: randSource,
portForwardPath: constants.PIAPortForward,
authFilePath: constants.OpenVPNAuthConf,
} }
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (p *Privatevpn) PortForward(ctx context.Context, client *http.Client, func (p *Privatevpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for PrivateVPN") panic("port forwarding is not supported for PrivateVPN")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (p *Protonvpn) PortForward(ctx context.Context, client *http.Client, func (p *Protonvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for ProtonVPN") panic("port forwarding is not supported for ProtonVPN")
} }

View File

@@ -30,7 +30,6 @@ import (
"github.com/qdm12/gluetun/internal/provider/vyprvpn" "github.com/qdm12/gluetun/internal/provider/vyprvpn"
"github.com/qdm12/gluetun/internal/provider/windscribe" "github.com/qdm12/gluetun/internal/provider/windscribe"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
// Provider contains methods to read and modify the openvpn configuration to connect as a client. // Provider contains methods to read and modify the openvpn configuration to connect as a client.
@@ -38,7 +37,7 @@ type Provider interface {
GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error) GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error)
BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string) BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string)
PortForward(ctx context.Context, client *http.Client, PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) syncState func(port uint16) (pfFilepath string))
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (p *Purevpn) PortForward(ctx context.Context, client *http.Client, func (p *Purevpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for PureVPN") panic("port forwarding is not supported for PureVPN")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (s *Surfshark) PortForward(ctx context.Context, client *http.Client, func (s *Surfshark) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Surfshark") panic("port forwarding is not supported for Surfshark")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (t *Torguard) PortForward(ctx context.Context, client *http.Client, func (t *Torguard) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Torguard") panic("port forwarding is not supported for Torguard")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (p *Provider) PortForward(ctx context.Context, client *http.Client, func (p *Provider) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for VPN Unlimited") panic("port forwarding is not supported for VPN Unlimited")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (v *Vyprvpn) PortForward(ctx context.Context, clienv *http.Client, func (v *Vyprvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Vyprvpn") panic("port forwarding is not supported for Vyprvpn")
} }

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
func (w *Windscribe) PortForward(ctx context.Context, clienw *http.Client, func (w *Windscribe) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) { syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Windscribe") panic("port forwarding is not supported for Windscribe")
} }

View File

@@ -1,13 +1,11 @@
package publicip package publicip
import "github.com/qdm12/golibs/os" import (
"os"
)
func persistPublicIP(openFile os.OpenFileFunc, func persistPublicIP(path string, content string, puid, pgid int) error {
filepath string, content string, puid, pgid int) error { file, err := os.OpenFile(path, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
file, err := openFile(
filepath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE,
0644)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"os"
"sync" "sync"
"time" "time"
@@ -11,7 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
type Looper interface { type Looper interface {
@@ -31,7 +31,6 @@ type looper struct {
getter IPGetter getter IPGetter
client *http.Client client *http.Client
logger logging.Logger logger logging.Logger
os os.OS
// Fixed settings // Fixed settings
puid int puid int
pgid int pgid int
@@ -51,8 +50,7 @@ type looper struct {
const defaultBackoffTime = 5 * time.Second const defaultBackoffTime = 5 * time.Second
func NewLooper(client *http.Client, logger logging.Logger, func NewLooper(client *http.Client, logger logging.Logger,
settings configuration.PublicIP, puid, pgid int, settings configuration.PublicIP, puid, pgid int) Looper {
os os.OS) Looper {
return &looper{ return &looper{
state: state{ state: state{
status: constants.Stopped, status: constants.Stopped,
@@ -62,7 +60,6 @@ func NewLooper(client *http.Client, logger logging.Logger,
client: client, client: client,
getter: NewIPGetter(client), getter: NewIPGetter(client),
logger: logger, logger: logger,
os: os,
puid: puid, puid: puid,
pgid: pgid, pgid: pgid,
start: make(chan struct{}), start: make(chan struct{}),
@@ -136,7 +133,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
close(errorCh) close(errorCh)
filepath := l.GetSettings().IPFilepath filepath := l.GetSettings().IPFilepath
l.logger.Info("Removing ip file " + filepath) l.logger.Info("Removing ip file " + filepath)
if err := l.os.Remove(filepath); err != nil { if err := os.Remove(filepath); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
return return
@@ -161,7 +158,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
} }
l.logger.Info(message) l.logger.Info(message)
err = persistPublicIP(l.os.OpenFile, l.state.settings.IPFilepath, err = persistPublicIP(l.state.settings.IPFilepath,
ip.String(), l.puid, l.pgid) ip.String(), l.puid, l.pgid)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)

View File

@@ -0,0 +1,7 @@
package storage
import "github.com/qdm12/gluetun/internal/models"
func (s *storage) FlushToFile(allServers models.AllServers) error {
return flushToFile(s.filepath, allServers)
}

View File

@@ -4,7 +4,6 @@ package storage
import ( import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
) )
type Storage interface { type Storage interface {
@@ -14,14 +13,12 @@ type Storage interface {
} }
type storage struct { type storage struct {
os os.OS
logger logging.Logger logger logging.Logger
filepath string filepath string
} }
func New(logger logging.Logger, os os.OS, filepath string) Storage { func New(logger logging.Logger, filepath string) Storage {
return &storage{ return &storage{
os: os,
logger: logger, logger: logger,
filepath: filepath, filepath: filepath,
} }

View File

@@ -5,11 +5,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/os"
) )
var ( var (
@@ -39,7 +39,7 @@ func countServers(allServers models.AllServers) int {
func (s *storage) SyncServers(hardcodedServers models.AllServers) ( func (s *storage) SyncServers(hardcodedServers models.AllServers) (
allServers models.AllServers, err error) { allServers models.AllServers, err error) {
serversOnFile, err := s.readFromFile(s.filepath) serversOnFile, err := readFromFile(s.filepath)
if err != nil { if err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err) return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err)
} }
@@ -62,14 +62,14 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers) (
return allServers, nil return allServers, nil
} }
if err := s.FlushToFile(allServers); err != nil { if err := flushToFile(s.filepath, allServers); err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err) return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
} }
return allServers, nil return allServers, nil
} }
func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) { func readFromFile(filepath string) (servers models.AllServers, err error) {
file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0) file, err := os.Open(filepath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return servers, nil return servers, nil
} else if err != nil { } else if err != nil {
@@ -86,13 +86,13 @@ func (s *storage) readFromFile(filepath string) (servers models.AllServers, err
return servers, file.Close() return servers, file.Close()
} }
func (s *storage) FlushToFile(servers models.AllServers) error { func flushToFile(path string, servers models.AllServers) error {
dirPath := filepath.Dir(s.filepath) dirPath := filepath.Dir(path)
if err := s.os.MkdirAll(dirPath, 0644); err != nil { if err := os.MkdirAll(dirPath, 0644); err != nil {
return err return err
} }
file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil { if err != nil {
return err return err
} }