Maint: do not mock os functions
- Use filepaths with /tmp for tests instead - Only mock functions where filepath can't be specified such as user.Lookup
This commit is contained in:
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
nativeos "os"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -33,8 +33,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
versionpkg "github.com/qdm12/gluetun/internal/version"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/os/user"
|
||||
"github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/goshutdown"
|
||||
"github.com/qdm12/gosplash"
|
||||
@@ -61,21 +59,19 @@ func main() {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, nativeos.Interrupt)
|
||||
ctx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
logger := logging.NewParent(logging.Settings{})
|
||||
|
||||
args := nativeos.Args
|
||||
os := os.New()
|
||||
osUser := user.New()
|
||||
args := os.Args
|
||||
unix := unix.New()
|
||||
cli := cli.New()
|
||||
env := params.NewEnv()
|
||||
|
||||
errorCh := make(chan error)
|
||||
go func() {
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, env, os, osUser, unix, cli)
|
||||
errorCh <- _main(ctx, buildInfo, args, logger, env, unix, cli)
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -86,7 +82,7 @@ func main() {
|
||||
stop()
|
||||
close(errorCh)
|
||||
if err == nil { // expected exit such as healthcheck
|
||||
nativeos.Exit(0)
|
||||
os.Exit(0)
|
||||
}
|
||||
logger.Error(err)
|
||||
cancel()
|
||||
@@ -104,7 +100,7 @@ func main() {
|
||||
logger.Warn("Shutdown timed out")
|
||||
}
|
||||
|
||||
nativeos.Exit(1)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -113,18 +109,18 @@ var (
|
||||
|
||||
//nolint:gocognit,gocyclo
|
||||
func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
args []string, logger logging.ParentLogger, env params.Env, os os.OS,
|
||||
osUser user.OSUser, unix unix.Unix, cli cli.CLI) error {
|
||||
args []string, logger logging.ParentLogger, env params.Env,
|
||||
unix unix.Unix, cli cli.CLI) error {
|
||||
if len(args) > 1 { // cli operation
|
||||
switch args[1] {
|
||||
case "healthcheck":
|
||||
return cli.HealthCheck(ctx, env, os, logger)
|
||||
return cli.HealthCheck(ctx, env, logger)
|
||||
case "clientkey":
|
||||
return cli.ClientKey(args[2:], os.OpenFile)
|
||||
return cli.ClientKey(args[2:])
|
||||
case "openvpnconfig":
|
||||
return cli.OpenvpnConfig(os, logger)
|
||||
return cli.OpenvpnConfig(logger)
|
||||
case "update":
|
||||
return cli.Update(ctx, args[2:], os, logger)
|
||||
return cli.Update(ctx, args[2:], logger)
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||
}
|
||||
@@ -133,19 +129,19 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
const clientTimeout = 15 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
// Create configurators
|
||||
alpineConf := alpine.NewConfigurator(os.OpenFile, osUser)
|
||||
alpineConf := alpine.NewConfigurator()
|
||||
ovpnConf := openvpn.NewConfigurator(
|
||||
logger.NewChild(logging.Settings{Prefix: "openvpn configurator: "}),
|
||||
os, unix)
|
||||
unix)
|
||||
dnsCrypto := dnscrypto.New(httpClient, "", "")
|
||||
const cacertsPath = "/etc/ssl/certs/ca-certificates.crt"
|
||||
dnsConf := unbound.NewConfigurator(nil, os.OpenFile, dnsCrypto,
|
||||
dnsConf := unbound.NewConfigurator(nil, dnsCrypto,
|
||||
"/etc/unbound", "/usr/sbin/unbound", cacertsPath)
|
||||
routingConf := routing.NewRouting(
|
||||
logger.NewChild(logging.Settings{Prefix: "routing: "}))
|
||||
firewallConf := firewall.NewConfigurator(
|
||||
logger.NewChild(logging.Settings{Prefix: "firewall: "}),
|
||||
routingConf, os.OpenFile)
|
||||
routingConf)
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2021-07-22T00:00:00Z")
|
||||
if err != nil {
|
||||
@@ -179,7 +175,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
|
||||
var allSettings configuration.Settings
|
||||
err = allSettings.Read(env, os,
|
||||
err = allSettings.Read(env,
|
||||
logger.NewChild(logging.Settings{Prefix: "configuration: "}))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -196,7 +192,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||
storage := storage.New(
|
||||
logger.NewChild(logging.Settings{Prefix: "storage: "}),
|
||||
os, constants.ServersData)
|
||||
constants.ServersData)
|
||||
allServers, err := storage.SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -314,7 +310,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings)
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, tunnelReadyCh)
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, tunnelReadyCh)
|
||||
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
|
||||
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
|
||||
// wait for restartOpenvpn
|
||||
@@ -331,7 +327,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
unboundLogger := logger.NewChild(logging.Settings{Prefix: "dns over tls: "})
|
||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, httpClient,
|
||||
unboundLogger, os.OpenFile)
|
||||
unboundLogger)
|
||||
dnsHandler, dnsCtx, dnsDone := goshutdown.NewGoRoutineHandler(
|
||||
"unbound", defaultGoRoutineSettings)
|
||||
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
|
||||
@@ -340,7 +336,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
publicIPLooper := publicip.NewLooper(httpClient,
|
||||
logger.NewChild(logging.Settings{Prefix: "ip getter: "}),
|
||||
allSettings.PublicIP, puid, pgid, os)
|
||||
allSettings.PublicIP, puid, pgid)
|
||||
pubIPHandler, pubIPCtx, pubIPDone := goshutdown.NewGoRoutineHandler(
|
||||
"public IP", defaultGoRoutineSettings)
|
||||
go publicIPLooper.Run(pubIPCtx, pubIPDone)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -5,7 +5,7 @@ go 1.16
|
||||
require (
|
||||
github.com/fatih/color v1.12.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/qdm12/dns v1.9.0
|
||||
github.com/qdm12/dns v1.10.0
|
||||
github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99
|
||||
github.com/qdm12/goshutdown v0.1.0
|
||||
github.com/qdm12/gosplash v0.1.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -63,8 +63,8 @@ github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMg
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/qdm12/dns v1.9.0 h1:p4g/BfbpQ+gJRpQdklDAnybkjds+OuenF0wEGoZ8/AI=
|
||||
github.com/qdm12/dns v1.9.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg=
|
||||
github.com/qdm12/dns v1.10.0 h1:WX5QQ5+2h34xfhfxJTmvyURbs9XE4qNrEGtyNeq38Bw=
|
||||
github.com/qdm12/dns v1.10.0/go.mod h1:fqZoDf3VzddnKBMNI/OzZUp5H4dO0VBw1fp4qPkolOg=
|
||||
github.com/qdm12/golibs v0.0.0-20210603202746-e5494e9c2ebb/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg=
|
||||
github.com/qdm12/golibs v0.0.0-20210716185557-66793f4ddd80/go.mod h1:15RBzkun0i8XB7ADIoLJWp9ITRgsz3LroEI2FiOXLRg=
|
||||
github.com/qdm12/golibs v0.0.0-20210721223530-ec1d3fe6dc99 h1:2OKHAR0SK8BtTtWCRNoSn58eh+iVDA3Cwq4i2CnD3i4=
|
||||
|
||||
@@ -3,9 +3,7 @@ package alpine
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/os/user"
|
||||
"os/user"
|
||||
)
|
||||
|
||||
type Configurator interface {
|
||||
@@ -14,13 +12,17 @@ type Configurator interface {
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
openFile os.OpenFileFunc
|
||||
osUser user.OSUser
|
||||
alpineReleasePath string
|
||||
passwdPath string
|
||||
lookupID func(uid string) (*user.User, error)
|
||||
lookup func(username string) (*user.User, error)
|
||||
}
|
||||
|
||||
func NewConfigurator(openFile os.OpenFileFunc, osUser user.OSUser) Configurator {
|
||||
func NewConfigurator() Configurator {
|
||||
return &configurator{
|
||||
openFile: openFile,
|
||||
osUser: osUser,
|
||||
alpineReleasePath: "/etc/alpine-release",
|
||||
passwdPath: "/etc/passwd",
|
||||
lookupID: user.LookupId,
|
||||
lookup: user.Lookup,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ var (
|
||||
// CreateUser creates a user in Alpine with the given UID.
|
||||
func (c *configurator) CreateUser(username string, uid int) (createdUsername string, err error) {
|
||||
UIDStr := strconv.Itoa(uid)
|
||||
u, err := c.osUser.LookupID(UIDStr)
|
||||
u, err := c.lookupID(UIDStr)
|
||||
_, unknownUID := err.(user.UnknownUserIdError)
|
||||
if err != nil && !unknownUID {
|
||||
return "", err
|
||||
@@ -28,7 +28,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
|
||||
return u.Username, nil
|
||||
}
|
||||
|
||||
u, err = c.osUser.Lookup(username)
|
||||
u, err = c.lookup(username)
|
||||
_, unknownUsername := err.(user.UnknownUserError)
|
||||
if err != nil && !unknownUsername {
|
||||
return "", err
|
||||
@@ -39,7 +39,7 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
|
||||
ErrUserAlreadyExists, username, u.Uid, uid)
|
||||
}
|
||||
|
||||
file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644)
|
||||
file, err := os.OpenFile(c.passwdPath, os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func (c *configurator) Version(ctx context.Context) (version string, err error) {
|
||||
file, err := c.openFile("/etc/alpine-release", os.O_RDONLY, 0)
|
||||
file, err := os.OpenFile(c.alpineReleasePath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -5,19 +5,22 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
type CLI interface {
|
||||
ClientKey(args []string, openFile os.OpenFileFunc) error
|
||||
HealthCheck(ctx context.Context, env params.Env, os os.OS, logger logging.Logger) error
|
||||
OpenvpnConfig(os os.OS, logger logging.Logger) error
|
||||
Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error
|
||||
ClientKey(args []string) error
|
||||
HealthCheck(ctx context.Context, env params.Env, logger logging.Logger) error
|
||||
OpenvpnConfig(logger logging.Logger) error
|
||||
Update(ctx context.Context, args []string, logger logging.Logger) error
|
||||
}
|
||||
|
||||
type cli struct{}
|
||||
type cli struct {
|
||||
repoServersPath string
|
||||
}
|
||||
|
||||
func New() CLI {
|
||||
return &cli{}
|
||||
return &cli{
|
||||
repoServersPath: "./internal/constants/servers.json",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (c *cli) ClientKey(args []string, openFile os.OpenFileFunc) error {
|
||||
func (c *cli) ClientKey(args []string) error {
|
||||
flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError)
|
||||
filepath := flagSet.String("path", constants.ClientKey, "file path to the client.key file")
|
||||
if err := flagSet.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
file, err := openFile(*filepath, os.O_RDONLY, 0)
|
||||
file, err := os.OpenFile(*filepath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,15 +9,14 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/healthcheck"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
func (c *cli) HealthCheck(ctx context.Context, env params.Env,
|
||||
os os.OS, logger logging.Logger) error {
|
||||
logger logging.Logger) error {
|
||||
// Extract the health server port from the configuration.
|
||||
config := configuration.Health{}
|
||||
err := config.Read(env, os, logger)
|
||||
err := config.Read(env, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -10,17 +10,16 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
func (c *cli) OpenvpnConfig(os os.OS, logger logging.Logger) error {
|
||||
func (c *cli) OpenvpnConfig(logger logging.Logger) error {
|
||||
var allSettings configuration.Settings
|
||||
err := allSettings.Read(params.NewEnv(), os, logger)
|
||||
err := allSettings.Read(params.NewEnv(), logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allServers, err := storage.New(logger, os, constants.ServersData).
|
||||
allServers, err := storage.New(logger, constants.ServersData).
|
||||
SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
nativeos "os"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -26,7 +25,7 @@ var (
|
||||
ErrWriteToFile = errors.New("cannot write updated information to file")
|
||||
)
|
||||
|
||||
func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger logging.Logger) error {
|
||||
func (c *cli) Update(ctx context.Context, args []string, logger logging.Logger) error {
|
||||
options := configuration.Updater{CLI: true}
|
||||
var endUserMode, maintainerMode, updateAll bool
|
||||
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
||||
@@ -65,7 +64,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
|
||||
|
||||
const clientTimeout = 10 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
storage := storage.New(logger, os, constants.ServersData)
|
||||
storage := storage.New(logger, constants.ServersData)
|
||||
currentServers, err := storage.SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrSyncServers, err)
|
||||
@@ -83,7 +82,7 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
|
||||
}
|
||||
|
||||
if maintainerMode {
|
||||
if err := writeToEmbeddedJSON(os, allServers); err != nil {
|
||||
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrWriteToFile, err)
|
||||
}
|
||||
}
|
||||
@@ -91,10 +90,11 @@ func (c *cli) Update(ctx context.Context, args []string, os os.OS, logger loggin
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeToEmbeddedJSON(os os.OS, allServers models.AllServers) error {
|
||||
func writeToEmbeddedJSON(repoServersPath string,
|
||||
allServers models.AllServers) error {
|
||||
const perms = 0600
|
||||
f, err := os.OpenFile("./internal/constants/servers.json",
|
||||
nativeos.O_TRUNC|nativeos.O_WRONLY|nativeos.O_CREATE, perms)
|
||||
f, err := os.OpenFile(repoServersPath,
|
||||
os.O_TRUNC|os.O_WRONLY|os.O_CREATE, perms)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
@@ -33,8 +32,8 @@ func (settings *Health) lines() (lines []string) {
|
||||
}
|
||||
|
||||
// Read is to be used for the healthcheck query mode.
|
||||
func (settings *Health) Read(env params.Env, os os.OS, logger logging.Logger) (err error) {
|
||||
reader := newReader(env, os, logger)
|
||||
func (settings *Health) Read(env params.Env, logger logging.Logger) (err error) {
|
||||
reader := newReader(env, logger)
|
||||
return settings.read(reader)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/golibs/verification"
|
||||
)
|
||||
@@ -17,15 +16,13 @@ type reader struct {
|
||||
env params.Env
|
||||
logger logging.Logger
|
||||
regex verification.Regex
|
||||
os os.OS
|
||||
}
|
||||
|
||||
func newReader(env params.Env, os os.OS, logger logging.Logger) reader {
|
||||
func newReader(env params.Env, logger logging.Logger) reader {
|
||||
return reader{
|
||||
env: env,
|
||||
logger: logger,
|
||||
regex: verification.NewRegex(),
|
||||
os: os,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ func (r *reader) getFromEnvOrSecretFile(envKey string, compulsory bool, retroKey
|
||||
ErrGetSecretFilepath, secretFilepathEnvKey, err)
|
||||
}
|
||||
|
||||
file, fileErr := r.os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
file, fileErr := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(fileErr) {
|
||||
if compulsory {
|
||||
return "", envErr
|
||||
@@ -85,7 +85,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
|
||||
return b, fmt.Errorf("environment variable %s: %w: %s", key, ErrGetSecretFilepath, err)
|
||||
}
|
||||
|
||||
b, err = readFromFile(r.os.OpenFile, secretFilepath)
|
||||
b, err = readFromFile(secretFilepath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return b, fmt.Errorf("%w: %s", ErrReadSecretFile, err)
|
||||
} else if err == nil {
|
||||
@@ -93,7 +93,7 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
|
||||
}
|
||||
|
||||
// Secret file does not exist, try the non secret file
|
||||
b, err = readFromFile(r.os.OpenFile, filepath)
|
||||
b, err = readFromFile(filepath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("%w: %s", ErrReadSecretFile, err)
|
||||
} else if err == nil {
|
||||
@@ -102,8 +102,8 @@ func (r *reader) getFromFileOrSecretFile(secretName, filepath string) (
|
||||
return nil, fmt.Errorf("%w: %s and %s", ErrFilesDoNotExist, secretFilepath, filepath)
|
||||
}
|
||||
|
||||
func readFromFile(openFile os.OpenFileFunc, filepath string) (b []byte, err error) {
|
||||
file, err := openFile(filepath, os.O_RDONLY, 0)
|
||||
func readFromFile(filepath string) (b []byte, err error) {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
@@ -62,8 +61,8 @@ var (
|
||||
|
||||
// Read obtains all configuration options for the program and returns an error as soon
|
||||
// as an error is encountered reading them.
|
||||
func (settings *Settings) Read(env params.Env, os os.OS, logger logging.Logger) (err error) {
|
||||
r := newReader(env, os, logger)
|
||||
func (settings *Settings) Read(env params.Env, logger logging.Logger) (err error) {
|
||||
r := newReader(env, logger)
|
||||
|
||||
settings.VersionInformation, err = r.env.OnOff("VERSION_INFORMATION", params.Default("on"))
|
||||
if err != nil {
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
@@ -33,6 +32,7 @@ type Looper interface {
|
||||
type looper struct {
|
||||
state *state
|
||||
conf unbound.Configurator
|
||||
resolvConf string
|
||||
blockBuilder blacklist.Builder
|
||||
client *http.Client
|
||||
logger logging.Logger
|
||||
@@ -45,13 +45,12 @@ type looper struct {
|
||||
backoffTime time.Duration
|
||||
timeNow func() time.Time
|
||||
timeSince func(time.Time) time.Duration
|
||||
openFile os.OpenFileFunc
|
||||
}
|
||||
|
||||
const defaultBackoffTime = 10 * time.Second
|
||||
|
||||
func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *http.Client,
|
||||
logger logging.Logger, openFile os.OpenFileFunc) Looper {
|
||||
logger logging.Logger) Looper {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
stop := make(chan struct{})
|
||||
@@ -63,6 +62,7 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
|
||||
return &looper{
|
||||
state: state,
|
||||
conf: conf,
|
||||
resolvConf: "/etc/resolv.conf",
|
||||
blockBuilder: blacklist.NewBuilder(client),
|
||||
client: client,
|
||||
logger: logger,
|
||||
@@ -75,7 +75,6 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
|
||||
backoffTime: defaultBackoffTime,
|
||||
timeNow: time.Now,
|
||||
timeSince: time.Since,
|
||||
openFile: openFile,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,8 +226,8 @@ func (l *looper) setupUnbound(ctx context.Context) (
|
||||
|
||||
// use Unbound
|
||||
nameserver.UseDNSInternally(net.IP{127, 0, 0, 1})
|
||||
err = nameserver.UseDNSSystemWide(l.openFile,
|
||||
net.IP{127, 0, 0, 1}, settings.KeepNameserver)
|
||||
err = nameserver.UseDNSSystemWide(l.resolvConf, net.IP{127, 0, 0, 1},
|
||||
settings.KeepNameserver)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
@@ -256,8 +255,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||
l.logger.Info("using plaintext DNS at address %s", targetIP)
|
||||
}
|
||||
nameserver.UseDNSInternally(targetIP)
|
||||
if err := nameserver.UseDNSSystemWide(l.openFile,
|
||||
targetIP, settings.KeepNameserver); err != nil {
|
||||
err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
return
|
||||
@@ -271,7 +270,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||
l.logger.Info("using plaintext DNS at address " + targetIP.String())
|
||||
}
|
||||
nameserver.UseDNSInternally(targetIP)
|
||||
if err := nameserver.UseDNSSystemWide(l.openFile, targetIP, settings.KeepNameserver); err != nil {
|
||||
err := nameserver.UseDNSSystemWide(l.resolvConf, targetIP, settings.KeepNameserver)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func (c *configurator) enable(ctx context.Context) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
|
||||
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrUserPostRules, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
// Configurator allows to change firewall rules and modify network routes.
|
||||
@@ -32,7 +31,6 @@ type configurator struct { //nolint:maligned
|
||||
commander command.Commander
|
||||
logger logging.Logger
|
||||
routing routing.Routing
|
||||
openFile os.OpenFileFunc // for custom iptables rules
|
||||
iptablesMutex sync.Mutex
|
||||
ip6tablesMutex sync.Mutex
|
||||
debug bool
|
||||
@@ -44,6 +42,7 @@ type configurator struct { //nolint:maligned
|
||||
|
||||
// Fixed state
|
||||
ip6Tables bool
|
||||
customRulesPath string
|
||||
|
||||
// State
|
||||
enabled bool
|
||||
@@ -54,15 +53,15 @@ type configurator struct { //nolint:maligned
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance.
|
||||
func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, routing routing.Routing) Configurator {
|
||||
commander := command.NewCommander()
|
||||
return &configurator{
|
||||
commander: commander,
|
||||
logger: logger,
|
||||
routing: routing,
|
||||
openFile: openFile,
|
||||
allowedInputPorts: make(map[uint16]string),
|
||||
ip6Tables: ip6tablesSupported(context.Background(), commander),
|
||||
customRulesPath: "/iptables/post-rules.txt",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port
|
||||
}
|
||||
|
||||
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
file, err := c.openFile(filepath, os.O_RDONLY, 0)
|
||||
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
|
||||
@@ -10,14 +10,14 @@ import (
|
||||
|
||||
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions.
|
||||
func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) error {
|
||||
file, err := c.os.OpenFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0)
|
||||
file, err := os.Open(c.authFilePath)
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_WRONLY|os.O_CREATE, 0400)
|
||||
file, err = os.OpenFile(c.authFilePath, os.O_WRONLY|os.O_CREATE, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func (c *configurator) WriteAuthFile(user, password string, puid, pgid int) erro
|
||||
}
|
||||
|
||||
c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf)
|
||||
file, err = c.os.OpenFile(constants.OpenVPNAuthConf, os.O_TRUNC|os.O_WRONLY, 0400)
|
||||
file, err = os.OpenFile(c.authFilePath, os.O_TRUNC|os.O_WRONLY, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -12,14 +13,13 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
var errProcessCustomConfig = errors.New("cannot process custom config")
|
||||
|
||||
func (l *looper) processCustomConfig(settings configuration.OpenVPN) (
|
||||
lines []string, connection models.OpenVPNConnection, err error) {
|
||||
lines, err = readCustomConfigLines(settings.Config, l.openFile)
|
||||
lines, err = readCustomConfigLines(settings.Config)
|
||||
if err != nil {
|
||||
return nil, connection, fmt.Errorf("%w: %s", errProcessCustomConfig, err)
|
||||
}
|
||||
@@ -35,9 +35,9 @@ func (l *looper) processCustomConfig(settings configuration.OpenVPN) (
|
||||
return lines, connection, nil
|
||||
}
|
||||
|
||||
func readCustomConfigLines(filepath string, openFile os.OpenFileFunc) (
|
||||
func readCustomConfigLines(filepath string) (
|
||||
lines []string, err error) {
|
||||
file, err := openFile(filepath, os.O_RDONLY, 0)
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
@@ -37,6 +37,7 @@ type looper struct {
|
||||
username string
|
||||
puid int
|
||||
pgid int
|
||||
targetConfPath string
|
||||
// Configurators
|
||||
conf Configurator
|
||||
fw firewall.Configurator
|
||||
@@ -44,7 +45,6 @@ type looper struct {
|
||||
// Other objects
|
||||
logger, pfLogger logging.Logger
|
||||
client *http.Client
|
||||
openFile os.OpenFileFunc
|
||||
tunnelReady chan<- struct{}
|
||||
// Internal channels and values
|
||||
stop <-chan struct{}
|
||||
@@ -64,7 +64,7 @@ const (
|
||||
func NewLooper(settings configuration.OpenVPN,
|
||||
username string, puid, pgid int, allServers models.AllServers,
|
||||
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
||||
logger logging.ParentLogger, client *http.Client, openFile os.OpenFileFunc,
|
||||
logger logging.ParentLogger, client *http.Client,
|
||||
tunnelReady chan<- struct{}) Looper {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
@@ -79,13 +79,13 @@ func NewLooper(settings configuration.OpenVPN,
|
||||
username: username,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
targetConfPath: constants.OpenVPNConf,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}),
|
||||
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}),
|
||||
client: client,
|
||||
openFile: openFile,
|
||||
tunnelReady: tunnelReady,
|
||||
start: start,
|
||||
running: running,
|
||||
@@ -145,7 +145,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeOpenvpnConf(lines, l.openFile); err != nil {
|
||||
if err := l.writeOpenvpnConf(lines); err != nil {
|
||||
l.signalOrSetStatus(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
@@ -279,13 +279,12 @@ func (l *looper) portForward(ctx context.Context,
|
||||
defer l.state.settingsMu.RUnlock()
|
||||
return settings.Provider.PortForwarding.Filepath
|
||||
}
|
||||
providerConf.PortForward(ctx,
|
||||
client, l.openFile, l.pfLogger,
|
||||
providerConf.PortForward(ctx, client, l.pfLogger,
|
||||
gateway, l.fw, syncState)
|
||||
}
|
||||
|
||||
func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error {
|
||||
file, err := openFile(constants.OpenVPNConf, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
func (l *looper) writeOpenvpnConf(lines []string) error {
|
||||
file, err := os.OpenFile(l.targetConfPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@ package openvpn
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/unix"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
type Configurator interface {
|
||||
@@ -24,15 +24,17 @@ type Configurator interface {
|
||||
type configurator struct {
|
||||
logger logging.Logger
|
||||
commander command.Commander
|
||||
os os.OS
|
||||
unix unix.Unix
|
||||
authFilePath string
|
||||
tunDevPath string
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, os os.OS, unix unix.Unix) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, unix unix.Unix) Configurator {
|
||||
return &configurator{
|
||||
logger: logger,
|
||||
commander: command.NewCommander(),
|
||||
os: os,
|
||||
unix: unix,
|
||||
authFilePath: constants.OpenVPNAuthConf,
|
||||
tunDevPath: constants.TunnelDevice,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,15 +3,15 @@ package openvpn
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/unix"
|
||||
)
|
||||
|
||||
// CheckTUN checks the tunnel device is present and accessible.
|
||||
func (c *configurator) CheckTUN() error {
|
||||
c.logger.Info("checking for device %s", constants.TunnelDevice)
|
||||
f, err := c.os.OpenFile(constants.TunnelDevice, os.O_RDWR, 0)
|
||||
c.logger.Info("checking for device " + c.tunDevPath)
|
||||
f, err := os.OpenFile(c.tunDevPath, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("TUN device is not available: %w", err)
|
||||
}
|
||||
@@ -22,8 +22,10 @@ func (c *configurator) CheckTUN() error {
|
||||
}
|
||||
|
||||
func (c *configurator) CreateTUN() error {
|
||||
c.logger.Info("creating %s", constants.TunnelDevice)
|
||||
if err := c.os.MkdirAll("/dev/net", 0751); err != nil { //nolint:gomnd
|
||||
c.logger.Info("creating " + c.tunDevPath)
|
||||
|
||||
parentDir := filepath.Dir(c.tunDevPath)
|
||||
if err := os.MkdirAll(parentDir, 0751); err != nil { //nolint:gomnd
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -32,17 +34,13 @@ func (c *configurator) CreateTUN() error {
|
||||
minor = 200
|
||||
)
|
||||
dev := c.unix.Mkdev(major, minor)
|
||||
if err := c.unix.Mknod(constants.TunnelDevice, unix.S_IFCHR, int(dev)); err != nil {
|
||||
if err := c.unix.Mknod(c.tunDevPath, unix.S_IFCHR, int(dev)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := c.os.OpenFile(constants.TunnelDevice, os.O_WRONLY, 0666) //nolint:gomnd
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
const readWriteAllPerms os.FileMode = 0666
|
||||
if err := file.Chmod(readWriteAllPerms); err != nil {
|
||||
_ = file.Close()
|
||||
file, err := os.OpenFile(c.tunDevPath, os.O_WRONLY, readWriteAllPerms)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (c *Cyberghost) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Cyberghost")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (f *Fastestvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for FastestVPN")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (f *HideMyAss) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for HideMyAss")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (i *Ipvanish) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Ipvanish")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (i *Ivpn) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Ivpn")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (m *Mullvad) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding logic is not needed for Mullvad")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (n *Nordvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for NordVPN")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (p *Privado) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Privado")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/format"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -28,7 +28,7 @@ var (
|
||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||
//nolint:gocognit
|
||||
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
commonName := p.activeServer.ServerName
|
||||
if !p.activeServer.PortForward {
|
||||
@@ -47,7 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
return
|
||||
}
|
||||
|
||||
data, err := readPIAPortForwardData(openFile)
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -67,7 +67,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
@@ -91,7 +92,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
@@ -128,7 +129,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
@@ -146,7 +148,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error("Cannot write port forward data to file: " + err.Error())
|
||||
}
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
@@ -168,8 +170,8 @@ var (
|
||||
)
|
||||
|
||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, openFile, client)
|
||||
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, client, authFilePath)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
|
||||
}
|
||||
@@ -179,7 +181,7 @@ func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *htt
|
||||
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
|
||||
}
|
||||
|
||||
if err := writePIAPortForwardData(openFile, data); err != nil {
|
||||
if err := writePIAPortForwardData(portForwardPath, data); err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
|
||||
}
|
||||
|
||||
@@ -199,8 +201,8 @@ type piaPortForwardData struct {
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
file, err := openFile(constants.PIAPortForward, os.O_RDONLY, 0)
|
||||
func readPIAPortForwardData(portForwardPath string) (data piaPortForwardData, err error) {
|
||||
file, err := os.Open(portForwardPath)
|
||||
if os.IsNotExist(err) {
|
||||
return data, nil
|
||||
} else if err != nil {
|
||||
@@ -216,8 +218,8 @@ func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData,
|
||||
return data, file.Close()
|
||||
}
|
||||
|
||||
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
|
||||
file, err := openFile(constants.PIAPortForward, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
func writePIAPortForwardData(portForwardPath string, data piaPortForwardData) (err error) {
|
||||
file, err := os.OpenFile(portForwardPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -269,9 +271,9 @@ var (
|
||||
errEmptyToken = errors.New("token received is empty")
|
||||
)
|
||||
|
||||
func fetchToken(ctx context.Context, openFile os.OpenFileFunc,
|
||||
client *http.Client) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(openFile)
|
||||
func fetchToken(ctx context.Context, client *http.Client,
|
||||
authFilePath string) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(authFilePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %s", errGetCredentials, err)
|
||||
}
|
||||
@@ -321,8 +323,9 @@ var (
|
||||
errAuthFileMalformed = errors.New("authentication file is malformed")
|
||||
)
|
||||
|
||||
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
|
||||
file, err := openFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0)
|
||||
func getOpenvpnCredentials(authFilePath string) (
|
||||
username, password string, err error) {
|
||||
file, err := os.Open(authFilePath)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
||||
}
|
||||
@@ -460,9 +463,8 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(openFile os.OpenFileFunc,
|
||||
filepath string, port uint16) (err error) {
|
||||
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
func writePortForwardedToFile(filepath string, port uint16) (err error) {
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
@@ -12,12 +13,18 @@ type PIA struct {
|
||||
randSource rand.Source
|
||||
timeNow func() time.Time
|
||||
activeServer models.PIAServer
|
||||
// Port forwarding
|
||||
portForwardPath string
|
||||
authFilePath string
|
||||
}
|
||||
|
||||
func New(servers []models.PIAServer, randSource rand.Source, timeNow func() time.Time) *PIA {
|
||||
func New(servers []models.PIAServer, randSource rand.Source,
|
||||
timeNow func() time.Time) *PIA {
|
||||
return &PIA{
|
||||
servers: servers,
|
||||
timeNow: timeNow,
|
||||
randSource: randSource,
|
||||
portForwardPath: constants.PIAPortForward,
|
||||
authFilePath: constants.OpenVPNAuthConf,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (p *Privatevpn) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for PrivateVPN")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (p *Protonvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for ProtonVPN")
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/provider/vyprvpn"
|
||||
"github.com/qdm12/gluetun/internal/provider/windscribe"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
// Provider contains methods to read and modify the openvpn configuration to connect as a client.
|
||||
@@ -38,7 +37,7 @@ type Provider interface {
|
||||
GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error)
|
||||
BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string)
|
||||
PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string))
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (p *Purevpn) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for PureVPN")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (s *Surfshark) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Surfshark")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (t *Torguard) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Torguard")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (p *Provider) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for VPN Unlimited")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (v *Vyprvpn) PortForward(ctx context.Context, clienv *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
func (v *Vyprvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Vyprvpn")
|
||||
}
|
||||
|
||||
@@ -7,11 +7,10 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (w *Windscribe) PortForward(ctx context.Context, clienw *http.Client,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP,
|
||||
fw firewall.Configurator, syncState func(port uint16) (pfFilepath string)) {
|
||||
func (w *Windscribe) PortForward(ctx context.Context, client *http.Client,
|
||||
pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("port forwarding is not supported for Windscribe")
|
||||
}
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package publicip
|
||||
|
||||
import "github.com/qdm12/golibs/os"
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func persistPublicIP(openFile os.OpenFileFunc,
|
||||
filepath string, content string, puid, pgid int) error {
|
||||
file, err := openFile(
|
||||
filepath,
|
||||
os.O_TRUNC|os.O_WRONLY|os.O_CREATE,
|
||||
0644)
|
||||
func persistPublicIP(path string, content string, puid, pgid int) error {
|
||||
file, err := os.OpenFile(path, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
@@ -31,7 +31,6 @@ type looper struct {
|
||||
getter IPGetter
|
||||
client *http.Client
|
||||
logger logging.Logger
|
||||
os os.OS
|
||||
// Fixed settings
|
||||
puid int
|
||||
pgid int
|
||||
@@ -51,8 +50,7 @@ type looper struct {
|
||||
const defaultBackoffTime = 5 * time.Second
|
||||
|
||||
func NewLooper(client *http.Client, logger logging.Logger,
|
||||
settings configuration.PublicIP, puid, pgid int,
|
||||
os os.OS) Looper {
|
||||
settings configuration.PublicIP, puid, pgid int) Looper {
|
||||
return &looper{
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
@@ -62,7 +60,6 @@ func NewLooper(client *http.Client, logger logging.Logger,
|
||||
client: client,
|
||||
getter: NewIPGetter(client),
|
||||
logger: logger,
|
||||
os: os,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
start: make(chan struct{}),
|
||||
@@ -136,7 +133,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||
close(errorCh)
|
||||
filepath := l.GetSettings().IPFilepath
|
||||
l.logger.Info("Removing ip file " + filepath)
|
||||
if err := l.os.Remove(filepath); err != nil {
|
||||
if err := os.Remove(filepath); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
return
|
||||
@@ -161,7 +158,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
|
||||
}
|
||||
l.logger.Info(message)
|
||||
|
||||
err = persistPublicIP(l.os.OpenFile, l.state.settings.IPFilepath,
|
||||
err = persistPublicIP(l.state.settings.IPFilepath,
|
||||
ip.String(), l.puid, l.pgid)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
|
||||
7
internal/storage/flush.go
Normal file
7
internal/storage/flush.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package storage
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/models"
|
||||
|
||||
func (s *storage) FlushToFile(allServers models.AllServers) error {
|
||||
return flushToFile(s.filepath, allServers)
|
||||
}
|
||||
@@ -4,7 +4,6 @@ package storage
|
||||
import (
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
@@ -14,14 +13,12 @@ type Storage interface {
|
||||
}
|
||||
|
||||
type storage struct {
|
||||
os os.OS
|
||||
logger logging.Logger
|
||||
filepath string
|
||||
}
|
||||
|
||||
func New(logger logging.Logger, os os.OS, filepath string) Storage {
|
||||
func New(logger logging.Logger, filepath string) Storage {
|
||||
return &storage{
|
||||
os: os,
|
||||
logger: logger,
|
||||
filepath: filepath,
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,7 +39,7 @@ func countServers(allServers models.AllServers) int {
|
||||
|
||||
func (s *storage) SyncServers(hardcodedServers models.AllServers) (
|
||||
allServers models.AllServers, err error) {
|
||||
serversOnFile, err := s.readFromFile(s.filepath)
|
||||
serversOnFile, err := readFromFile(s.filepath)
|
||||
if err != nil {
|
||||
return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err)
|
||||
}
|
||||
@@ -62,14 +62,14 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers) (
|
||||
return allServers, nil
|
||||
}
|
||||
|
||||
if err := s.FlushToFile(allServers); err != nil {
|
||||
if err := flushToFile(s.filepath, allServers); err != nil {
|
||||
return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
|
||||
}
|
||||
return allServers, nil
|
||||
}
|
||||
|
||||
func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) {
|
||||
file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
func readFromFile(filepath string) (servers models.AllServers, err error) {
|
||||
file, err := os.Open(filepath)
|
||||
if os.IsNotExist(err) {
|
||||
return servers, nil
|
||||
} else if err != nil {
|
||||
@@ -86,13 +86,13 @@ func (s *storage) readFromFile(filepath string) (servers models.AllServers, err
|
||||
return servers, file.Close()
|
||||
}
|
||||
|
||||
func (s *storage) FlushToFile(servers models.AllServers) error {
|
||||
dirPath := filepath.Dir(s.filepath)
|
||||
if err := s.os.MkdirAll(dirPath, 0644); err != nil {
|
||||
func flushToFile(path string, servers models.AllServers) error {
|
||||
dirPath := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dirPath, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user