Maint: do not mock os functions

- Use filepaths with /tmp for tests instead
- Only mock functions where filepath can't be specified such as user.Lookup
This commit is contained in:
Quentin McGaw (desktop)
2021-07-23 16:06:19 +00:00
parent e94684aa39
commit 21f4cf7ab5
48 changed files with 226 additions and 243 deletions

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"net"
"net/http"
nativeos "os"
"os"
"os/signal"
"strconv"
"strings"
@@ -33,8 +33,6 @@ import (
"github.com/qdm12/gluetun/internal/updater"
versionpkg "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/user"
"github.com/qdm12/golibs/params"
"github.com/qdm12/goshutdown"
"github.com/qdm12/gosplash"
@@ -61,21 +59,19 @@ func main() {
}
ctx := context.Background()
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, nativeos.Interrupt)
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
ctx, cancel := context.WithCancel(ctx)
logger := logging.NewParent(logging.Settings{})
args := nativeos.Args
os := os.New()
osUser := user.New()
args := os.Args
unix := unix.New()
cli := cli.New()
env := params.NewEnv()
errorCh := make(chan error)
go func() {
errorCh <- _main(ctx, buildInfo, args, logger, env, os, osUser, unix, cli)
errorCh <- _main(ctx, buildInfo, args, logger, env, unix, cli)
}()
select {
@@ -86,7 +82,7 @@ func main() {
stop()
close(errorCh)
if err == nil { // expected exit such as healthcheck
nativeos.Exit(0)
os.Exit(0)
}
logger.Error(err)
cancel()
@@ -104,7 +100,7 @@ func main() {
logger.Warn("Shutdown timed out")
}
nativeos.Exit(1)
os.Exit(1)
}
var (
@@ -113,18 +109,18 @@ var (
//nolint:gocognit,gocyclo
func _main(ctx context.Context, buildInfo models.BuildInformation,
args []string, logger logging.ParentLogger, env params.Env, os os.OS,
osUser user.OSUser, unix unix.Unix, cli cli.CLI) error {
args []string, logger logging.ParentLogger, env params.Env,
unix unix.Unix, cli cli.CLI) error {
if len(args) > 1 { // cli operation
switch args[1] {
case "healthcheck":
return cli.HealthCheck(ctx, env, os, logger)
return cli.HealthCheck(ctx, env, logger)
case "clientkey":
return cli.ClientKey(args[2:], os.OpenFile)
return cli.ClientKey(args[2:])
case "openvpnconfig":
return cli.OpenvpnConfig(os, logger)
return cli.OpenvpnConfig(logger)
case "update":
return cli.Update(ctx, args[2:], os, logger)
return cli.Update(ctx, args[2:], logger)
default:
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
}
@@ -133,19 +129,19 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
const clientTimeout = 15 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
// Create configurators
alpineConf := alpine.NewConfigurator(os.OpenFile, osUser)
alpineConf := alpine.NewConfigurator()
ovpnConf := openvpn.NewConfigurator(
logger.NewChild(logging.Settings{Prefix: "openvpn configurator: "}),
os, unix)
unix)
dnsCrypto := dnscrypto.New(httpClient, "", "")
const cacertsPath = "/etc/ssl/certs/ca-certificates.crt"
dnsConf := unbound.NewConfigurator(nil, os.OpenFile, dnsCrypto,
dnsConf := unbound.NewConfigurator(nil, dnsCrypto,
"/etc/unbound", "/usr/sbin/unbound", cacertsPath)
routingConf := routing.NewRouting(
logger.NewChild(logging.Settings{Prefix: "routing: "}))
firewallConf := firewall.NewConfigurator(
logger.NewChild(logging.Settings{Prefix: "firewall: "}),
routingConf, os.OpenFile)
routingConf)
announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z")
if err != nil {
@@ -179,7 +175,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
}
var allSettings configuration.Settings
err = allSettings.Read(env, os,
err = allSettings.Read(env,
logger.NewChild(logging.Settings{Prefix: "configuration: "}))
if err != nil {
return err
@@ -196,7 +192,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
// TODO run this in a loop or in openvpn to reload from file without restarting
storage := storage.New(
logger.NewChild(logging.Settings{Prefix: "storage: "}),
os, constants.ServersData)
constants.ServersData)
allServers, err := storage.SyncServers(constants.GetAllServers())
if err != nil {
return err
@@ -314,7 +310,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings)
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh)
ovpnConf, firewallConf, routingConf, logger, httpClient, tunnelReadyCh)
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
// wait for restartOpenvpn
@@ -331,7 +327,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
unboundLogger := logger.NewChild(logging.Settings{Prefix: "dns over tls: "})
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient,
unboundLogger, os.OpenFile)
unboundLogger)
dnsHandler, dnsCtx, dnsDone := goshutdown.NewGoRoutineHandler(
"unbound", defaultGoRoutineSettings)
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
@@ -340,7 +336,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
publicIPLooper := publicip.NewLooper(httpClient,
logger.NewChild(logging.Settings{Prefix: "ip getter: "}),
allSettings.PublicIP, puid, pgid, os)
allSettings.PublicIP, puid, pgid)
pubIPHandler, pubIPCtx, pubIPDone := goshutdown.NewGoRoutineHandler(
"public IP", defaultGoRoutineSettings)
go publicIPLooper.Run(pubIPCtx, pubIPDone)

2
go.mod
View File

@@ -5,7 +5,7 @@ go 1.16
require (
github.com/fatih/color v1.12.0
github.com/golang/mock v1.6.0
github.com/qdm12/dns v1.9.0
github.com/qdm12/dns v1.10.0
github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99
github.com/qdm12/goshutdown v0.1.0
github.com/qdm12/gosplash v0.1.0

4
go.sum
View File

@@ -63,8 +63,8 @@ github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMg
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/qdm12/dns v1.9.0 h1:p4g/BfbpQ+gJRpQdklDAnybkjds+OuenF0wEGoZ8/AI=
github.com/qdm12/dns v1.9.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg=
github.com/qdm12/dns v1.10.0 h1:WX5QQ5+2h34xfhfxJTmvyURbs9XE4qNrEGtyNeq38Bw=
github.com/qdm12/dns v1.10.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg=
github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg=
github.com/qdm12/golibs v0.0.0-20210716185557-66793f4ddd80/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg=
github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 h1:2OKHAR0SK8BtTtWCRNoSn58eh+iVDA3Cwq4i2CnD3i4=

View File

@@ -3,9 +3,7 @@ package alpine
import (
"context"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/os/user"
"os/user"
)
type Configurator interface {
@@ -14,13 +12,17 @@ type Configurator interface {
}
type configurator struct {
openFile os.OpenFileFunc
osUser user.OSUser
alpineReleasePath string
passwdPath string
lookupID func(uid string) (*user.User, error)
lookup func(username string) (*user.User, error)
}
func NewConfigurator(openFile os.OpenFileFunc, osUser user.OSUser) Configurator {
func NewConfigurator() Configurator {
return &configurator{
openFile: openFile,
osUser: osUser,
alpineReleasePath: "/etc/alpine-release",
passwdPath: "/etc/passwd",
lookupID: user.LookupId,
lookup: user.Lookup,
}
}

View File

@@ -15,7 +15,7 @@ var (
// CreateUser creates a user in Alpine with the given UID.
func (c *configurator) CreateUser(username string, uid int) (createdUsername string, err error) {
UIDStr := strconv.Itoa(uid)
u, err := c.osUser.LookupID(UIDStr)
u, err := c.lookupID(UIDStr)
_, unknownUID := err.(user.UnknownUserIdError)
if err != nil && !unknownUID {
return "", err
@@ -28,7 +28,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
return u.Username, nil
}
u, err = c.osUser.Lookup(username)
u, err = c.lookup(username)
_, unknownUsername := err.(user.UnknownUserError)
if err != nil && !unknownUsername {
return "", err
@@ -39,7 +39,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
ErrUserAlreadyExists, username, u.Uid, uid)
}
file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644)
file, err := os.OpenFile(c.passwdPath, os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return "", err
}

View File

@@ -8,7 +8,7 @@ import (
)
func (c *configurator) Version(ctx context.Context) (version string, err error) {
file, err := c.openFile("/etc/alpine-release", os.O_RDONLY, 0)
file, err := os.OpenFile(c.alpineReleasePath, os.O_RDONLY, 0)
if err != nil {
return "", err
}

View File

@@ -5,19 +5,22 @@ import (
"context"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
)
type CLI interface {
ClientKey(args []string, openFile os.OpenFileFunc) error
HealthCheck(ctx context.Context, env params.Env, os os.OS, logger logging.Logger) error
OpenvpnConfig(os os.OS, logger logging.Logger) error
Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error
ClientKey(args []string) error
HealthCheck(ctx context.Context, env params.Env, logger logging.Logger) error
OpenvpnConfig(logger logging.Logger) error
Update(ctx context.Context, args []string, logger logging.Logger) error
}
type cli struct{}
type cli struct {
repoServersPath string
}
func New() CLI {
return &cli{}
return &cli{
repoServersPath: "./internal/constants/servers.json",
}
}

View File

@@ -4,19 +4,19 @@ import (
"flag"
"fmt"
"io"
"os"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/os"
)
func (c *cli) ClientKey(args []string, openFile os.OpenFileFunc) error {
func (c *cli) ClientKey(args []string) error {
flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError)
filepath := flagSet.String("path", constants.ClientKey, "file path to the client.key file")
if err := flagSet.Parse(args); err != nil {
return err
}
file, err := openFile(*filepath, os.O_RDONLY, 0)
file, err := os.OpenFile(*filepath, os.O_RDONLY, 0)
if err != nil {
return err
}

View File

@@ -9,15 +9,14 @@ import (
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/healthcheck"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
)
func (c *cli) HealthCheck(ctx context.Context, env params.Env,
os os.OS, logger logging.Logger) error {
logger logging.Logger) error {
// Extract the health server port from the configuration.
config := configuration.Health{}
err := config.Read(env, os, logger)
err := config.Read(env, logger)
if err != nil {
return err
}

View File

@@ -10,17 +10,16 @@ import (
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
)
func (c *cli) OpenvpnConfig(os os.OS, logger logging.Logger) error {
func (c *cli) OpenvpnConfig(logger logging.Logger) error {
var allSettings configuration.Settings
err := allSettings.Read(params.NewEnv(), os, logger)
err := allSettings.Read(params.NewEnv(), logger)
if err != nil {
return err
}
allServers, err := storage.New(logger, os, constants.ServersData).
allServers, err := storage.New(logger, constants.ServersData).
SyncServers(constants.GetAllServers())
if err != nil {
return err

View File

@@ -7,7 +7,7 @@ import (
"flag"
"fmt"
"net/http"
nativeos "os"
"os"
"time"
"github.com/qdm12/gluetun/internal/configuration"
@@ -16,7 +16,6 @@ import (
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
var (
@@ -26,7 +25,7 @@ var (
ErrWriteToFile = errors.New("cannot write updated information to file")
)
func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error {
func (c *cli) Update(ctx context.Context, args []string, logger logging.Logger) error {
options := configuration.Updater{CLI: true}
var endUserMode, maintainerMode, updateAll bool
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
@@ -65,7 +64,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
storage := storage.New(logger, os, constants.ServersData)
storage := storage.New(logger, constants.ServersData)
currentServers, err := storage.SyncServers(constants.GetAllServers())
if err != nil {
return fmt.Errorf("%w: %s", ErrSyncServers, err)
@@ -83,7 +82,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
}
if maintainerMode {
if err := writeToEmbeddedJSON(os, allServers); err != nil {
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
return fmt.Errorf("%w: %s", ErrWriteToFile, err)
}
}
@@ -91,10 +90,11 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
return nil
}
func writeToEmbeddedJSON(os os.OS, allServers models.AllServers) error {
func writeToEmbeddedJSON(repoServersPath string,
allServers models.AllServers) error {
const perms = 0600
f, err := os.OpenFile("./internal/constants/servers.json",
nativeos.O_TRUNC|nativeos.O_WRONLY|nativeos.O_CREATE, perms)
f, err := os.OpenFile(repoServersPath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
if err != nil {
return err
}

View File

@@ -5,7 +5,6 @@ import (
"strings"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
)
@@ -33,8 +32,8 @@ func (settings *Health) lines() (lines []string) {
}
// Read is to be used for the healthcheck query mode.
func (settings *Health) Read(env params.Env, os os.OS, logger logging.Logger) (err error) {
reader := newReader(env, os, logger)
func (settings *Health) Read(env params.Env, logger logging.Logger) (err error) {
reader := newReader(env, logger)
return settings.read(reader)
}

View File

@@ -8,7 +8,6 @@ import (
"strings"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
"github.com/qdm12/golibs/verification"
)
@@ -17,15 +16,13 @@ type reader struct {
env params.Env
logger logging.Logger
regex verification.Regex
os os.OS
}
func newReader(env params.Env, os os.OS, logger logging.Logger) reader {
func newReader(env params.Env, logger logging.Logger) reader {
return reader{
env: env,
logger: logger,
regex: verification.NewRegex(),
os: os,
}
}

View File

@@ -4,9 +4,9 @@ import (
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
)
@@ -48,7 +48,7 @@ func (r *reader) getFromEnvOrSecretFile(envKey string, compulsory bool, retroKey
ErrGetSecretFilepath, secretFilepathEnvKey, err)
}
file, fileErr := r.os.OpenFile(filepath, os.O_RDONLY, 0)
file, fileErr := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(fileErr) {
if compulsory {
return "", envErr
@@ -85,7 +85,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
return b, fmt.Errorf("environment variable %s: %w: %s", key, ErrGetSecretFilepath, err)
}
b, err = readFromFile(r.os.OpenFile, secretFilepath)
b, err = readFromFile(secretFilepath)
if err != nil && !os.IsNotExist(err) {
return b, fmt.Errorf("%w: %s", ErrReadSecretFile, err)
} else if err == nil {
@@ -93,7 +93,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
}
// Secret file does not exist, try the non secret file
b, err = readFromFile(r.os.OpenFile, filepath)
b, err = readFromFile(filepath)
if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("%w: %s", ErrReadSecretFile, err)
} else if err == nil {
@@ -102,8 +102,8 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
return nil, fmt.Errorf("%w: %s and %s", ErrFilesDoNotExist, secretFilepath, filepath)
}
func readFromFile(openFile os.OpenFileFunc, filepath string) (b []byte, err error) {
file, err := openFile(filepath, os.O_RDONLY, 0)
func readFromFile(filepath string) (b []byte, err error) {
file, err := os.Open(filepath)
if err != nil {
return nil, err
}

View File

@@ -6,7 +6,6 @@ import (
"strings"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
"github.com/qdm12/golibs/params"
)
@@ -62,8 +61,8 @@ var (
// Read obtains all configuration options for the program and returns an error as soon
// as an error is encountered reading them.
func (settings *Settings) Read(env params.Env, os os.OS, logger logging.Logger) (err error) {
r := newReader(env, os, logger)
func (settings *Settings) Read(env params.Env, logger logging.Logger) (err error) {
r := newReader(env, logger)
settings.VersionInformation, err = r.env.OnOff("VERSION_INFORMATION", params.Default("on"))
if err != nil {

View File

@@ -16,7 +16,6 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
type Looper interface {
@@ -33,6 +32,7 @@ type Looper interface {
type looper struct {
state *state
conf unbound.Configurator
resolvConf string
blockBuilder blacklist.Builder
client *http.Client
logger logging.Logger
@@ -45,13 +45,12 @@ type looper struct {
backoffTime time.Duration
timeNow func() time.Time
timeSince func(time.Time) time.Duration
openFile os.OpenFileFunc
}
const defaultBackoffTime = 10 * time.Second
func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *http.Client,
logger logging.Logger, openFile os.OpenFileFunc) Looper {
logger logging.Logger) Looper {
start := make(chan struct{})
running := make(chan models.LoopStatus)
stop := make(chan struct{})
@@ -63,6 +62,7 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
return &looper{
state: state,
conf: conf,
resolvConf: "/etc/resolv.conf",
blockBuilder: blacklist.NewBuilder(client),
client: client,
logger: logger,
@@ -75,7 +75,6 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
backoffTime: defaultBackoffTime,
timeNow: time.Now,
timeSince: time.Since,
openFile: openFile,
}
}
@@ -227,8 +226,8 @@ func (l *looper) setupUnbound(ctx context.Context) (
// use Unbound
nameserver.UseDNSInternally(net.IP{127, 0, 0, 1})
err = nameserver.UseDNSSystemWide(l.openFile,
net.IP{127, 0, 0, 1}, settings.KeepNameserver)
err = nameserver.UseDNSSystemWide(l.resolvConf, net.IP{127, 0, 0, 1},
settings.KeepNameserver)
if err != nil {
l.logger.Error(err)
}
@@ -256,8 +255,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
l.logger.Info("using plaintext DNS at address %s", targetIP)
}
nameserver.UseDNSInternally(targetIP)
if err := nameserver.UseDNSSystemWide(l.openFile,
targetIP, settings.KeepNameserver); err != nil {
err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver)
if err != nil {
l.logger.Error(err)
}
return
@@ -271,7 +270,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
l.logger.Info("using plaintext DNS at address " + targetIP.String())
}
nameserver.UseDNSInternally(targetIP)
if err := nameserver.UseDNSSystemWide(l.openFile, targetIP, settings.KeepNameserver); err != nil {
err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver)
if err != nil {
l.logger.Error(err)
}
}

View File

@@ -136,7 +136,7 @@ func (c *configurator) enable(ctx context.Context) (err error) {
}
}
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("%w: %s", ErrUserPostRules, err)
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
// Configurator allows to change firewall rules and modify network routes.
@@ -32,7 +31,6 @@ type configurator struct { //nolint:maligned
commander command.Commander
logger logging.Logger
routing routing.Routing
openFile os.OpenFileFunc // for custom iptables rules
iptablesMutex sync.Mutex
ip6tablesMutex sync.Mutex
debug bool
@@ -43,7 +41,8 @@ type configurator struct { //nolint:maligned
networkInfoMutex sync.Mutex
// Fixed state
ip6Tables bool
ip6Tables bool
customRulesPath string
// State
enabled bool
@@ -54,15 +53,15 @@ type configurator struct { //nolint:maligned
}
// NewConfigurator creates a new Configurator instance.
func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator {
func NewConfigurator(logger logging.Logger, routing routing.Routing) Configurator {
commander := command.NewCommander()
return &configurator{
commander: commander,
logger: logger,
routing: routing,
openFile: openFile,
allowedInputPorts: make(map[uint16]string),
ip6Tables: ip6tablesSupported(context.Background(), commander),
customRulesPath: "/iptables/post-rules.txt",
}
}

View File

@@ -196,7 +196,7 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port
}
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := c.openFile(filepath, os.O_RDONLY, 0)
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
} else if err != nil {

View File

@@ -10,14 +10,14 @@ import (
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions.
func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) error {
file, err := c.os.OpenFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0)
file, err := os.Open(c.authFilePath)
if err != nil && !os.IsNotExist(err) {
return err
}
if os.IsNotExist(err) {
file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_WRONLY|os.O_CREATE, 0400)
file, err = os.OpenFile(c.authFilePath, os.O_WRONLY|os.O_CREATE, 0400)
if err != nil {
return err
}
@@ -49,7 +49,7 @@ func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) erro
}
c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf)
file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_TRUNC|os.O_WRONLY, 0400)
file, err = os.OpenFile(c.authFilePath, os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil {
return err
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
@@ -12,14 +13,13 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/os"
)
var errProcessCustomConfig = errors.New("cannot process custom config")
func (l *looper) processCustomConfig(settings configuration.OpenVPN) (
lines []string, connection models.OpenVPNConnection, err error) {
lines, err = readCustomConfigLines(settings.Config, l.openFile)
lines, err = readCustomConfigLines(settings.Config)
if err != nil {
return nil, connection, fmt.Errorf("%w: %s", errProcessCustomConfig, err)
}
@@ -35,9 +35,9 @@ func (l *looper) processCustomConfig(settings configuration.OpenVPN) (
return lines, connection, nil
}
func readCustomConfigLines(filepath string, openFile os.OpenFileFunc) (
func readCustomConfigLines(filepath string) (
lines []string, err error) {
file, err := openFile(filepath, os.O_RDONLY, 0)
file, err := os.Open(filepath)
if err != nil {
return nil, err
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"os"
"strings"
"time"
@@ -14,7 +15,6 @@ import (
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
type Looper interface {
@@ -34,9 +34,10 @@ type Looper interface {
type looper struct {
state *state
// Fixed parameters
username string
puid int
pgid int
username string
puid int
pgid int
targetConfPath string
// Configurators
conf Configurator
fw firewall.Configurator
@@ -44,7 +45,6 @@ type looper struct {
// Other objects
logger, pfLogger logging.Logger
client *http.Client
openFile os.OpenFileFunc
tunnelReady chan<- struct{}
// Internal channels and values
stop <-chan struct{}
@@ -64,7 +64,7 @@ const (
func NewLooper(settings configuration.OpenVPN,
username string, puid, pgid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.ParentLogger, client *http.Client, openFile os.OpenFileFunc,
logger logging.ParentLogger, client *http.Client,
tunnelReady chan<- struct{}) Looper {
start := make(chan struct{})
running := make(chan models.LoopStatus)
@@ -79,13 +79,13 @@ func NewLooper(settings configuration.OpenVPN,
username: username,
puid: puid,
pgid: pgid,
targetConfPath: constants.OpenVPNConf,
conf: conf,
fw: fw,
routing: routing,
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}),
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}),
client: client,
openFile: openFile,
tunnelReady: tunnelReady,
start: start,
running: running,
@@ -145,7 +145,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
}
}
if err := writeOpenvpnConf(lines, l.openFile); err != nil {
if err := l.writeOpenvpnConf(lines); err != nil {
l.signalOrSetStatus(constants.Crashed)
l.logAndWait(ctx, err)
continue
@@ -279,13 +279,12 @@ func (l *looper) portForward(ctx context.Context,
defer l.state.settingsMu.RUnlock()
return settings.Provider.PortForwarding.Filepath
}
providerConf.PortForward(ctx,
client, l.openFile, l.pfLogger,
providerConf.PortForward(ctx, client, l.pfLogger,
gateway, l.fw, syncState)
}
func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error {
file, err := openFile(constants.OpenVPNConf, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
func (l *looper) writeOpenvpnConf(lines []string) error {
file, err := os.OpenFile(l.targetConfPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}

View File

@@ -5,10 +5,10 @@ package openvpn
import (
"context"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/unix"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
type Configurator interface {
@@ -22,17 +22,19 @@ type Configurator interface {
}
type configurator struct {
logger logging.Logger
commander command.Commander
os os.OS
unix unix.Unix
logger logging.Logger
commander command.Commander
unix unix.Unix
authFilePath string
tunDevPath string
}
func NewConfigurator(logger logging.Logger, os os.OS, unix unix.Unix) Configurator {
func NewConfigurator(logger logging.Logger, unix unix.Unix) Configurator {
return &configurator{
logger: logger,
commander: command.NewCommander(),
os: os,
unix: unix,
logger: logger,
commander: command.NewCommander(),
unix: unix,
authFilePath: constants.OpenVPNAuthConf,
tunDevPath: constants.TunnelDevice,
}
}

View File

@@ -3,15 +3,15 @@ package openvpn
import (
"fmt"
"os"
"path/filepath"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/unix"
)
// CheckTUN checks the tunnel device is present and accessible.
func (c *configurator) CheckTUN() error {
c.logger.Info("checking for device %s", constants.TunnelDevice)
f, err := c.os.OpenFile(constants.TunnelDevice, os.O_RDWR, 0)
c.logger.Info("checking for device " + c.tunDevPath)
f, err := os.OpenFile(c.tunDevPath, os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("TUN device is not available: %w", err)
}
@@ -22,8 +22,10 @@ func (c *configurator) CheckTUN() error {
}
func (c *configurator) CreateTUN() error {
c.logger.Info("creating %s", constants.TunnelDevice)
if err := c.os.MkdirAll("/dev/net", 0751); err != nil { //nolint:gomnd
c.logger.Info("creating " + c.tunDevPath)
parentDir := filepath.Dir(c.tunDevPath)
if err := os.MkdirAll(parentDir, 0751); err != nil { //nolint:gomnd
return err
}
@@ -32,17 +34,13 @@ func (c *configurator) CreateTUN() error {
minor = 200
)
dev := c.unix.Mkdev(major, minor)
if err := c.unix.Mknod(constants.TunnelDevice, unix.S_IFCHR, int(dev)); err != nil {
if err := c.unix.Mknod(c.tunDevPath, unix.S_IFCHR, int(dev)); err != nil {
return err
}
file, err := c.os.OpenFile(constants.TunnelDevice, os.O_WRONLY, 0666) //nolint:gomnd
if err != nil {
return err
}
const readWriteAllPerms os.FileMode = 0666
if err := file.Chmod(readWriteAllPerms); err != nil {
_ = file.Close()
file, err := os.OpenFile(c.tunDevPath, os.O_WRONLY, readWriteAllPerms)
if err != nil {
return err
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (c *Cyberghost) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Cyberghost")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (f *Fastestvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for FastestVPN")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (f *HideMyAss) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for HideMyAss")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (i *Ipvanish) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Ipvanish")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (i *Ivpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Ivpn")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (m *Mullvad) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding logic is not needed for Mullvad")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (n *Nordvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for NordVPN")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (p *Privado) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Privado")
}

View File

@@ -10,6 +10,7 @@ import (
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
@@ -18,7 +19,6 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/format"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
var (
@@ -28,7 +28,7 @@ var (
// PortForward obtains a VPN server side port forwarded from PIA.
//nolint:gocognit
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
logger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
commonName := p.activeServer.ServerName
if !p.activeServer.PortForward {
@@ -47,7 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
return
}
data, err := readPIAPortForwardData(openFile)
data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil {
logger.Error(err)
}
@@ -67,7 +67,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired {
tryUntilSuccessful(ctx, logger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
return err
})
if ctx.Err() != nil {
@@ -91,7 +92,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error(err)
}
@@ -128,7 +129,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
data.Expiration.Format(time.RFC1123) + ", getting another one")
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
logger.Error(err)
continue
@@ -146,7 +148,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
}
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error("Cannot write port forward data to file: " + err.Error())
}
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
@@ -168,8 +170,8 @@ var (
)
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, openFile, client)
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, authFilePath)
if err != nil {
return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
}
@@ -179,7 +181,7 @@ func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *htt
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
}
if err := writePIAPortForwardData(openFile, data); err != nil {
if err := writePIAPortForwardData(portForwardPath, data); err != nil {
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
}
@@ -199,8 +201,8 @@ type piaPortForwardData struct {
Expiration time.Time `json:"expires_at"`
}
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
file, err := openFile(constants.PIAPortForward, os.O_RDONLY, 0)
func readPIAPortForwardData(portForwardPath string) (data piaPortForwardData, err error) {
file, err := os.Open(portForwardPath)
if os.IsNotExist(err) {
return data, nil
} else if err != nil {
@@ -216,8 +218,8 @@ func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData,
return data, file.Close()
}
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
file, err := openFile(constants.PIAPortForward, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
func writePIAPortForwardData(portForwardPath string, data piaPortForwardData) (err error) {
file, err := os.OpenFile(portForwardPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
@@ -269,9 +271,9 @@ var (
errEmptyToken = errors.New("token received is empty")
)
func fetchToken(ctx context.Context, openFile os.OpenFileFunc,
client *http.Client) (token string, err error) {
username, password, err := getOpenvpnCredentials(openFile)
func fetchToken(ctx context.Context, client *http.Client,
authFilePath string) (token string, err error) {
username, password, err := getOpenvpnCredentials(authFilePath)
if err != nil {
return "", fmt.Errorf("%w: %s", errGetCredentials, err)
}
@@ -321,8 +323,9 @@ var (
errAuthFileMalformed = errors.New("authentication file is malformed")
)
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
file, err := openFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0)
func getOpenvpnCredentials(authFilePath string) (
username, password string, err error) {
file, err := os.Open(authFilePath)
if err != nil {
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
}
@@ -460,9 +463,8 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
return nil
}
func writePortForwardedToFile(openFile os.OpenFileFunc,
filepath string, port uint16) (err error) {
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
func writePortForwardedToFile(filepath string, port uint16) (err error) {
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}

View File

@@ -4,6 +4,7 @@ import (
"math/rand"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
@@ -12,12 +13,18 @@ type PIA struct {
randSource rand.Source
timeNow func() time.Time
activeServer models.PIAServer
// Port forwarding
portForwardPath string
authFilePath string
}
func New(servers []models.PIAServer, randSource rand.Source, timeNow func() time.Time) *PIA {
func New(servers []models.PIAServer, randSource rand.Source,
timeNow func() time.Time) *PIA {
return &PIA{
servers: servers,
timeNow: timeNow,
randSource: randSource,
servers: servers,
timeNow: timeNow,
randSource: randSource,
portForwardPath: constants.PIAPortForward,
authFilePath: constants.OpenVPNAuthConf,
}
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (p *Privatevpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for PrivateVPN")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (p *Protonvpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for ProtonVPN")
}

View File

@@ -30,7 +30,6 @@ import (
"github.com/qdm12/gluetun/internal/provider/vyprvpn"
"github.com/qdm12/gluetun/internal/provider/windscribe"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
// Provider contains methods to read and modify the openvpn configuration to connect as a client.
@@ -38,7 +37,7 @@ type Provider interface {
GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error)
BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string)
PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string))
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (p *Purevpn) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for PureVPN")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (s *Surfshark) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Surfshark")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (t *Torguard) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Torguard")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (p *Provider) PortForward(ctx context.Context, client *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for VPN Unlimited")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (v *Vyprvpn) PortForward(ctx context.Context, clienv *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
func (v *Vyprvpn) PortForward(ctx context.Context, client *http.Client,
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Vyprvpn")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
func (w *Windscribe) PortForward(ctx context.Context, clienw *http.Client,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
func (w *Windscribe) PortForward(ctx context.Context, client *http.Client,
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath string)) {
panic("port forwarding is not supported for Windscribe")
}

View File

@@ -1,13 +1,11 @@
package publicip
import "github.com/qdm12/golibs/os"
import (
"os"
)
func persistPublicIP(openFile os.OpenFileFunc,
filepath string, content string, puid, pgid int) error {
file, err := openFile(
filepath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE,
0644)
func persistPublicIP(path string, content string, puid, pgid int) error {
file, err := os.OpenFile(path, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return err
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"os"
"sync"
"time"
@@ -11,7 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
type Looper interface {
@@ -31,7 +31,6 @@ type looper struct {
getter IPGetter
client *http.Client
logger logging.Logger
os os.OS
// Fixed settings
puid int
pgid int
@@ -51,8 +50,7 @@ type looper struct {
const defaultBackoffTime = 5 * time.Second
func NewLooper(client *http.Client, logger logging.Logger,
settings configuration.PublicIP, puid, pgid int,
os os.OS) Looper {
settings configuration.PublicIP, puid, pgid int) Looper {
return &looper{
state: state{
status: constants.Stopped,
@@ -62,7 +60,6 @@ func NewLooper(client *http.Client, logger logging.Logger,
client: client,
getter: NewIPGetter(client),
logger: logger,
os: os,
puid: puid,
pgid: pgid,
start: make(chan struct{}),
@@ -136,7 +133,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
close(errorCh)
filepath := l.GetSettings().IPFilepath
l.logger.Info("Removing ip file " + filepath)
if err := l.os.Remove(filepath); err != nil {
if err := os.Remove(filepath); err != nil {
l.logger.Error(err)
}
return
@@ -161,7 +158,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
}
l.logger.Info(message)
err = persistPublicIP(l.os.OpenFile, l.state.settings.IPFilepath,
err = persistPublicIP(l.state.settings.IPFilepath,
ip.String(), l.puid, l.pgid)
if err != nil {
l.logger.Error(err)

View File

@@ -0,0 +1,7 @@
package storage
import "github.com/qdm12/gluetun/internal/models"
func (s *storage) FlushToFile(allServers models.AllServers) error {
return flushToFile(s.filepath, allServers)
}

View File

@@ -4,7 +4,6 @@ package storage
import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/os"
)
type Storage interface {
@@ -14,14 +13,12 @@ type Storage interface {
}
type storage struct {
os os.OS
logger logging.Logger
filepath string
}
func New(logger logging.Logger, os os.OS, filepath string) Storage {
func New(logger logging.Logger, filepath string) Storage {
return &storage{
os: os,
logger: logger,
filepath: filepath,
}

View File

@@ -5,11 +5,11 @@ import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/os"
)
var (
@@ -39,7 +39,7 @@ func countServers(allServers models.AllServers) int {
func (s *storage) SyncServers(hardcodedServers models.AllServers) (
allServers models.AllServers, err error) {
serversOnFile, err := s.readFromFile(s.filepath)
serversOnFile, err := readFromFile(s.filepath)
if err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err)
}
@@ -62,14 +62,14 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers) (
return allServers, nil
}
if err := s.FlushToFile(allServers); err != nil {
if err := flushToFile(s.filepath, allServers); err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
}
return allServers, nil
}
func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) {
file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0)
func readFromFile(filepath string) (servers models.AllServers, err error) {
file, err := os.Open(filepath)
if os.IsNotExist(err) {
return servers, nil
} else if err != nil {
@@ -86,13 +86,13 @@ func (s *storage) readFromFile(filepath string) (servers models.AllServers, err
return servers, file.Close()
}
func (s *storage) FlushToFile(servers models.AllServers) error {
dirPath := filepath.Dir(s.filepath)
if err := s.os.MkdirAll(dirPath, 0644); err != nil {
func flushToFile(path string, servers models.AllServers) error {
dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil {
return err
}
file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}