chore(errors): review all errors in codebase

This commit is contained in:
Quentin McGaw
2022-02-20 02:58:16 +00:00
parent ac4a4f83fc
commit 920ad8b54b
88 changed files with 254 additions and 460 deletions

View File

@@ -56,11 +56,6 @@ var (
created = "an unknown date" created = "an unknown date"
) )
var (
errSetupRouting = errors.New("cannot setup routing")
errCreateUser = errors.New("cannot create user")
)
func main() { func main() {
buildInfo := models.BuildInformation{ buildInfo := models.BuildInformation{
Version: version, Version: version,
@@ -278,7 +273,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
const defaultUsername = "nonrootuser" const defaultUsername = "nonrootuser"
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid) nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errCreateUser, err) return fmt.Errorf("cannot create user: %w", err)
} }
if nonRootUsername != defaultUsername { if nonRootUsername != defaultUsername {
logger.Info("using existing username " + nonRootUsername + " corresponding to user id " + fmt.Sprint(puid)) logger.Info("using existing username " + nonRootUsername + " corresponding to user id " + fmt.Sprint(puid))
@@ -296,7 +291,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
if strings.Contains(err.Error(), "operation not permitted") { if strings.Contains(err.Error(), "operation not permitted") {
logger.Warn("💡 Tip: Are you passing NET_ADMIN capability to gluetun?") logger.Warn("💡 Tip: Are you passing NET_ADMIN capability to gluetun?")
} }
return fmt.Errorf("%w: %s", errSetupRouting, err) return fmt.Errorf("cannot setup routing: %w", err)
} }
defer func() { defer func() {
logger.Info("routing cleanup...") logger.Info("routing cleanup...")

View File

@@ -18,9 +18,6 @@ type ServersFormatter interface {
var ( var (
ErrFormatNotRecognized = errors.New("format is not recognized") ErrFormatNotRecognized = errors.New("format is not recognized")
ErrProviderUnspecified = errors.New("VPN provider to format was not specified") ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
ErrOpenOutputFile = errors.New("cannot open output file")
ErrWriteOutput = errors.New("cannot write to output file")
ErrCloseOutputFile = errors.New("cannot close output file")
) )
func (c *CLI) FormatServers(args []string) error { func (c *CLI) FormatServers(args []string) error {
@@ -62,7 +59,7 @@ func (c *CLI) FormatServers(args []string) error {
logger := newNoopLogger() logger := newNoopLogger()
storage, err := storage.New(logger, constants.ServersData) storage, err := storage.New(logger, constants.ServersData)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrNewStorage, err) return fmt.Errorf("cannot create servers storage: %w", err)
} }
currentServers := storage.GetServers() currentServers := storage.GetServers()
@@ -115,18 +112,18 @@ func (c *CLI) FormatServers(args []string) error {
output = filepath.Clean(output) output = filepath.Clean(output)
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrOpenOutputFile, err) return fmt.Errorf("cannot open output file: %w", err)
} }
_, err = fmt.Fprint(file, formatted) _, err = fmt.Fprint(file, formatted)
if err != nil { if err != nil {
_ = file.Close() _ = file.Close()
return fmt.Errorf("%w: %s", ErrWriteOutput, err) return fmt.Errorf("cannot write to output file: %w", err)
} }
err = file.Close() err = file.Close()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrCloseOutputFile, err) return fmt.Errorf("cannot close output file: %w", err)
} }
return nil return nil

View File

@@ -23,9 +23,6 @@ var (
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified") ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
ErrDNSAddress = errors.New("DNS address is not valid") ErrDNSAddress = errors.New("DNS address is not valid")
ErrNoProviderSpecified = errors.New("no provider was specified") ErrNoProviderSpecified = errors.New("no provider was specified")
ErrNewStorage = errors.New("cannot create storage")
ErrUpdateServerInformation = errors.New("cannot update server information")
ErrWriteToFile = errors.New("cannot write updated information to file")
) )
type Updater interface { type Updater interface {
@@ -90,25 +87,25 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
storage, err := storage.New(logger, constants.ServersData) storage, err := storage.New(logger, constants.ServersData)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrNewStorage, err) return fmt.Errorf("cannot create servers storage: %w", err)
} }
currentServers := storage.GetServers() currentServers := storage.GetServers()
updater := updater.New(options, httpClient, currentServers, logger) updater := updater.New(options, httpClient, currentServers, logger)
allServers, err := updater.UpdateServers(ctx) allServers, err := updater.UpdateServers(ctx)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrUpdateServerInformation, err) return fmt.Errorf("cannot update server information: %w", err)
} }
if endUserMode { if endUserMode {
if err := storage.FlushToFile(allServers); err != nil { if err := storage.FlushToFile(allServers); err != nil {
return fmt.Errorf("%w: %s", ErrWriteToFile, err) return fmt.Errorf("cannot write updated information to file: %w", err)
} }
} }
if maintainerMode { if maintainerMode {
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil { if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
return fmt.Errorf("%w: %s", ErrWriteToFile, err) return fmt.Errorf("cannot write updated information to file: %w", err)
} }
} }

View File

@@ -4,19 +4,15 @@ import "errors"
var ( var (
ErrCityNotValid = errors.New("the city specified is not valid") ErrCityNotValid = errors.New("the city specified is not valid")
ErrControlServerAddress = errors.New("listening address it not valid")
ErrControlServerPort = errors.New("listening port it not valid")
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root") ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
ErrCountryNotValid = errors.New("the country specified is not valid") ErrCountryNotValid = errors.New("the country specified is not valid")
ErrFilepathMissing = errors.New("filepath is missing")
ErrFirewallZeroPort = errors.New("cannot have a zero port to block") ErrFirewallZeroPort = errors.New("cannot have a zero port to block")
ErrHostnameNotValid = errors.New("the hostname specified is not valid") ErrHostnameNotValid = errors.New("the hostname specified is not valid")
ErrISPNotValid = errors.New("the ISP specified is not valid") ErrISPNotValid = errors.New("the ISP specified is not valid")
ErrMissingValue = errors.New("missing value")
ErrNameNotValid = errors.New("the server name specified is not valid") ErrNameNotValid = errors.New("the server name specified is not valid")
ErrOpenVPNClientCertMissing = errors.New("client certificate is missing")
ErrOpenVPNClientCertNotValid = errors.New("client certificate is not valid")
ErrOpenVPNClientKeyMissing = errors.New("client key is missing") ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
ErrOpenVPNClientKeyNotValid = errors.New("client key is not valid")
ErrOpenVPNConfigFile = errors.New("custom configuration file error")
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed") ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid") ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid") ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
@@ -27,8 +23,6 @@ var (
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds") ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid") ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled") ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
ErrPortForwardingFilepathNotValid = errors.New("port forwarding filepath given is not valid")
ErrPublicIPFilepathNotValid = errors.New("public IP address file path is not valid")
ErrPublicIPPeriodTooShort = errors.New("public IP address check period is too short") ErrPublicIPPeriodTooShort = errors.New("public IP address check period is too short")
ErrRegionNotValid = errors.New("the region specified is not valid") ErrRegionNotValid = errors.New("the region specified is not valid")
ErrServerAddressNotValid = errors.New("server listening address is not valid") ErrServerAddressNotValid = errors.New("server listening address is not valid")
@@ -44,9 +38,7 @@ var (
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set") ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid") ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set") ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
ErrWireguardPreSharedKeyNotValid = errors.New("pre-shared key is not valid")
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set") ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
ErrWireguardPrivateKeyNotValid = errors.New("private key is not valid")
ErrWireguardPublicKeyNotSet = errors.New("public key is not set") ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid") ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
) )

View File

@@ -27,13 +27,12 @@ func (h Health) Validate() (err error) {
_, err = address.Validate(h.ServerAddress, _, err = address.Validate(h.ServerAddress,
address.OptionListening(uid)) address.OptionListening(uid))
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", return fmt.Errorf("server listening address is not valid: %w", err)
ErrServerAddressNotValid, err)
} }
err = h.VPN.validate() err = h.VPN.validate()
if err != nil { if err != nil {
return fmt.Errorf("health VPN settings validation failed: %w", err) return fmt.Errorf("health VPN settings: %w", err)
} }
return nil return nil

View File

@@ -41,8 +41,7 @@ func (h HTTPProxy) validate() (err error) {
uid := os.Getuid() uid := os.Getuid()
_, err = address.Validate(h.ListeningAddress, address.OptionListening(uid)) _, err = address.Validate(h.ListeningAddress, address.OptionListening(uid))
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
ErrServerAddressNotValid, h.ListeningAddress)
} }
return nil return nil

View File

@@ -93,17 +93,17 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile) err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
if err != nil { if err != nil {
return err return fmt.Errorf("custom configuration file: %w", err)
} }
err = validateOpenVPNClientCertificate(vpnProvider, *o.ClientCrt) err = validateOpenVPNClientCertificate(vpnProvider, *o.ClientCrt)
if err != nil { if err != nil {
return err return fmt.Errorf("client certificate: %w", err)
} }
err = validateOpenVPNClientKey(vpnProvider, *o.ClientKey) err = validateOpenVPNClientKey(vpnProvider, *o.ClientKey)
if err != nil { if err != nil {
return err return fmt.Errorf("client key: %w", err)
} }
const maxMSSFix = 10000 const maxMSSFix = 10000
@@ -132,12 +132,12 @@ func validateOpenVPNConfigFilepath(isCustom bool,
} }
if confFile == "" { if confFile == "" {
return fmt.Errorf("%w: no file path specified", ErrOpenVPNConfigFile) return ErrFilepathMissing
} }
err = helpers.FileExists(confFile) err = helpers.FileExists(confFile)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrOpenVPNConfigFile, err) return err
} }
return nil return nil
@@ -150,7 +150,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
constants.Cyberghost, constants.Cyberghost,
constants.VPNUnlimited: constants.VPNUnlimited:
if clientCert == "" { if clientCert == "" {
return ErrOpenVPNClientCertMissing return ErrMissingValue
} }
} }
@@ -160,7 +160,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
_, err = parse.ExtractCert([]byte(clientCert)) _, err = parse.ExtractCert([]byte(clientCert))
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrOpenVPNClientCertNotValid, err) return err
} }
return nil return nil
} }
@@ -172,7 +172,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
constants.VPNUnlimited, constants.VPNUnlimited,
constants.Wevpn: constants.Wevpn:
if clientKey == "" { if clientKey == "" {
return ErrOpenVPNClientKeyMissing return ErrMissingValue
} }
} }
@@ -182,7 +182,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
_, err = parse.ExtractPrivateKey([]byte(clientKey)) _, err = parse.ExtractPrivateKey([]byte(clientKey))
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrOpenVPNClientKeyNotValid, err) return err
} }
return nil return nil
} }

View File

@@ -33,14 +33,16 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
if confFile := *o.ConfFile; confFile != "" { if confFile := *o.ConfFile; confFile != "" {
err := helpers.FileExists(confFile) err := helpers.FileExists(confFile)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrOpenVPNConfigFile, err) return fmt.Errorf("configuration file: %w", err)
} }
} }
// Validate TCP // Validate TCP
if *o.TCP && helpers.IsOneOf(vpnProvider, if *o.TCP && helpers.IsOneOf(vpnProvider,
constants.Ipvanish,
constants.Perfectprivacy, constants.Perfectprivacy,
constants.Privado, constants.Privado,
constants.VPNUnlimited,
constants.Vyprvpn, constants.Vyprvpn,
) { ) {
return fmt.Errorf("%w: for VPN service provider %s", return fmt.Errorf("%w: for VPN service provider %s",

View File

@@ -38,7 +38,7 @@ func (p PortForwarding) validate(vpnProvider string) (err error) {
if *p.Filepath != "" { // optional if *p.Filepath != "" { // optional
_, err := filepath.Abs(*p.Filepath) _, err := filepath.Abs(*p.Filepath)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrPortForwardingFilepathNotValid, err) return fmt.Errorf("filepath is not valid: %w", err)
} }
} }

View File

@@ -43,12 +43,12 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
err = p.ServerSelection.validate(*p.Name, allServers) err = p.ServerSelection.validate(*p.Name, allServers)
if err != nil { if err != nil {
return fmt.Errorf("server selection settings validation failed: %w", err) return fmt.Errorf("server selection: %w", err)
} }
err = p.PortForwarding.validate(*p.Name) err = p.PortForwarding.validate(*p.Name)
if err != nil { if err != nil {
return fmt.Errorf("port forwarding settings validation failed: %w", err) return fmt.Errorf("port forwarding: %w", err)
} }
return nil return nil

View File

@@ -33,7 +33,7 @@ func (p PublicIP) validate() (err error) {
if *p.IPFilepath != "" { // optional if *p.IPFilepath != "" { // optional
_, err := filepath.Abs(*p.IPFilepath) _, err := filepath.Abs(*p.IPFilepath)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrPublicIPFilepathNotValid, err) return fmt.Errorf("filepath is not valid: %w", err)
} }
} }

View File

@@ -23,12 +23,12 @@ type ControlServer struct {
func (c ControlServer) validate() (err error) { func (c ControlServer) validate() (err error) {
_, portStr, err := net.SplitHostPort(*c.Address) _, portStr, err := net.SplitHostPort(*c.Address)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrControlServerAddress, err) return fmt.Errorf("listening address is not valid: %w", err)
} }
port, err := strconv.Atoi(portStr) port, err := strconv.Atoi(portStr)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrControlServerPort, err) return fmt.Errorf("listening port it not valid: %w", err)
} }
uid := os.Getuid() uid := os.Getuid()

View File

@@ -116,18 +116,18 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
if *ss.MultiHopOnly && if *ss.MultiHopOnly &&
vpnServiceProvider != constants.Surfshark { vpnServiceProvider != constants.Surfshark {
return fmt.Errorf("%w: for VPN service provider %s", return fmt.Errorf("%w: for VPN service provider %s",
ErrStreamOnlyNotSupported, vpnServiceProvider) ErrMultiHopOnlyNotSupported, vpnServiceProvider)
} }
if ss.VPN == constants.OpenVPN { if ss.VPN == constants.OpenVPN {
err = ss.OpenVPN.validate(vpnServiceProvider) err = ss.OpenVPN.validate(vpnServiceProvider)
if err != nil { if err != nil {
return fmt.Errorf("OpenVPN server selection settings validation failed: %w", err) return fmt.Errorf("OpenVPN server selection settings: %w", err)
} }
} else { } else {
err = ss.Wireguard.validate(vpnServiceProvider) err = ss.Wireguard.validate(vpnServiceProvider)
if err != nil { if err != nil {
return fmt.Errorf("Wireguard server selection settings validation failed: %w", err) return fmt.Errorf("Wireguard server selection settings: %w", err)
} }
} }

View File

@@ -49,7 +49,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) {
for name, validation := range nameToValidation { for name, validation := range nameToValidation {
err = validation() err = validation()
if err != nil { if err != nil {
return fmt.Errorf("failed validating %s settings: %w", name, err) return fmt.Errorf("%s settings: %w", name, err)
} }
} }

View File

@@ -31,18 +31,18 @@ func (v *VPN) validate(allServers models.AllServers) (err error) {
err = v.Provider.validate(v.Type, allServers) err = v.Provider.validate(v.Type, allServers)
if err != nil { if err != nil {
return fmt.Errorf("provider settings validation failed: %w", err) return fmt.Errorf("provider settings: %w", err)
} }
if v.Type == constants.OpenVPN { if v.Type == constants.OpenVPN {
err := v.OpenVPN.validate(*v.Provider.Name) err := v.OpenVPN.validate(*v.Provider.Name)
if err != nil { if err != nil {
return fmt.Errorf("OpenVPN settings validation failed: %w", err) return fmt.Errorf("OpenVPN settings: %w", err)
} }
} else { } else {
err := v.Wireguard.validate(*v.Provider.Name) err := v.Wireguard.validate(*v.Provider.Name)
if err != nil { if err != nil {
return fmt.Errorf("Wireguard settings validation failed: %w", err) return fmt.Errorf("Wireguard settings: %w", err)
} }
} }

View File

@@ -50,14 +50,14 @@ func (w Wireguard) validate(vpnProvider string) (err error) {
} }
_, err = wgtypes.ParseKey(*w.PrivateKey) _, err = wgtypes.ParseKey(*w.PrivateKey)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrWireguardPrivateKeyNotValid, err) return fmt.Errorf("private key is not valid: %w", err)
} }
// Validate PreSharedKey // Validate PreSharedKey
if *w.PreSharedKey != "" { // Note: this is optional if *w.PreSharedKey != "" { // Note: this is optional
_, err = wgtypes.ParseKey(*w.PreSharedKey) _, err = wgtypes.ParseKey(*w.PreSharedKey)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrWireguardPreSharedKeyNotValid, err) return fmt.Errorf("pre-shared key is not valid: %w", err)
} }
} }

View File

@@ -20,7 +20,7 @@ func (r *Reader) readDNS() (dns settings.DNS, err error) {
dns.DoT, err = r.readDoT() dns.DoT, err = r.readDoT()
if err != nil { if err != nil {
return dns, fmt.Errorf("cannot read DoT settings: %w", err) return dns, fmt.Errorf("DoT settings: %w", err)
} }
return dns, nil return dns, nil

View File

@@ -55,8 +55,7 @@ func stringsToPorts(ss []string) (ports []uint16, err error) {
for i, s := range ss { for i, s := range ss {
port, err := strconv.Atoi(s) port, err := strconv.Atoi(s)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s: %s", return nil, fmt.Errorf("%w: %s: %s", ErrPortParsing, s, err)
ErrPortParsing, s, err)
} else if port < 1 || port > 65535 { } else if port < 1 || port > 65535 {
return nil, fmt.Errorf("%w: must be between 1 and 65535: %d", return nil, fmt.Errorf("%w: must be between 1 and 65535: %d",
ErrPortValue, port) ErrPortValue, port)
@@ -66,10 +65,6 @@ func stringsToPorts(ss []string) (ports []uint16, err error) {
return ports, nil return ports, nil
} }
var (
ErrIPNetParsing = errors.New("cannot parse IP network")
)
func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) { func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) {
if len(ss) == 0 { if len(ss) == 0 {
return nil, nil return nil, nil
@@ -78,8 +73,7 @@ func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) {
for i, s := range ss { for i, s := range ss {
ip, ipNet, err := net.ParseCIDR(s) ip, ipNet, err := net.ParseCIDR(s)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s: %s", return nil, fmt.Errorf("cannot parse IP network %q: %w", s, err)
ErrIPNetParsing, s, err)
} }
ipNet.IP = ip ipNet.IP = ip
ipNets[i] = *ipNet ipNets[i] = *ipNet

View File

@@ -38,9 +38,7 @@ func (r *Reader) readDurationWithRetro(envKey, retroEnvKey string) (d *time.Dura
d = new(time.Duration) d = new(time.Duration)
*d, err = time.ParseDuration(s) *d, err = time.ParseDuration(s)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("environment variable %s: %w", envKey, err)
"environment variable %s: %w",
envKey, err)
} }
return d, nil return d, nil

View File

@@ -2,7 +2,6 @@ package env
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@@ -115,13 +114,11 @@ func lowerAndSplit(csv string) (values []string) {
return strings.Split(csv, ",") return strings.Split(csv, ",")
} }
var ErrDecodeBase64 = errors.New("cannot decode base64 string")
func decodeBase64(b64String string) (decoded string, err error) { func decodeBase64(b64String string) (decoded string, err error) {
b, err := base64.StdEncoding.DecodeString(b64String) b, err := base64.StdEncoding.DecodeString(b64String)
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s: %s", return "", fmt.Errorf("cannot decode base64 string %q: %w",
ErrDecodeBase64, b64String, err) b64String, err)
} }
return string(b), nil return string(b), nil
} }

View File

@@ -48,7 +48,7 @@ func parseLogLevel(s string) (level logging.Level, err error) {
return logging.LevelError, nil return logging.LevelError, nil
default: default:
return level, fmt.Errorf( return level, fmt.Errorf(
"%w: %s: can be one of: debug, info, warning or error", "%w: %q is not valid and can be one of debug, info, warning or error",
ErrLogLevelUnknown, s) ErrLogLevelUnknown, s)
} }
} }

View File

@@ -18,12 +18,12 @@ func (r *Reader) readProvider(vpnType string) (provider settings.Provider, err e
provider.ServerSelection, err = r.readServerSelection(providerName, vpnType) provider.ServerSelection, err = r.readServerSelection(providerName, vpnType)
if err != nil { if err != nil {
return provider, fmt.Errorf("cannot read server selection settings: %w", err) return provider, fmt.Errorf("server selection: %w", err)
} }
provider.PortForwarding, err = r.readPortForward() provider.PortForwarding, err = r.readPortForward()
if err != nil { if err != nil {
return provider, fmt.Errorf("cannot read port forwarding settings: %w", err) return provider, fmt.Errorf("port forwarding: %w", err)
} }
return provider, nil return provider, nil

View File

@@ -13,17 +13,17 @@ func (r *Reader) readVPN() (vpn settings.VPN, err error) {
vpn.Provider, err = r.readProvider(vpn.Type) vpn.Provider, err = r.readProvider(vpn.Type)
if err != nil { if err != nil {
return vpn, fmt.Errorf("cannot read provider settings: %w", err) return vpn, fmt.Errorf("VPN provider: %w", err)
} }
vpn.OpenVPN, err = r.readOpenVPN() vpn.OpenVPN, err = r.readOpenVPN()
if err != nil { if err != nil {
return vpn, fmt.Errorf("cannot read OpenVPN settings: %w", err) return vpn, fmt.Errorf("OpenVPN: %w", err)
} }
vpn.Wireguard, err = r.readWireguard() vpn.Wireguard, err = r.readWireguard()
if err != nil { if err != nil {
return vpn, fmt.Errorf("cannot read Wireguard settings: %w", err) return vpn, fmt.Errorf("wireguard: %w", err)
} }
return vpn, nil return vpn, nil

View File

@@ -16,12 +16,12 @@ const (
func (r *Reader) readOpenVPN() (settings settings.OpenVPN, err error) { func (r *Reader) readOpenVPN() (settings settings.OpenVPN, err error) {
settings.ClientKey, err = ReadFromFile(OpenVPNClientKeyPath) settings.ClientKey, err = ReadFromFile(OpenVPNClientKeyPath)
if err != nil { if err != nil {
return settings, fmt.Errorf("cannot read client key: %w", err) return settings, fmt.Errorf("client key: %w", err)
} }
settings.ClientCrt, err = ReadFromFile(OpenVPNClientCertificatePath) settings.ClientCrt, err = ReadFromFile(OpenVPNClientCertificatePath)
if err != nil { if err != nil {
return settings, fmt.Errorf("cannot read client certificate: %w", err) return settings, fmt.Errorf("client certificate: %w", err)
} }
return settings, nil return settings, nil

View File

@@ -9,7 +9,7 @@ import (
func (r *Reader) readVPN() (vpn settings.VPN, err error) { func (r *Reader) readVPN() (vpn settings.VPN, err error) {
vpn.OpenVPN, err = r.readOpenVPN() vpn.OpenVPN, err = r.readOpenVPN()
if err != nil { if err != nil {
return vpn, fmt.Errorf("cannot read OpenVPN settings: %w", err) return vpn, fmt.Errorf("OpenVPN: %w", err)
} }
return vpn, nil return vpn, nil

View File

@@ -26,7 +26,7 @@ func (r *Reader) Read() (settings settings.Settings, err error) {
for _, source := range r.sources { for _, source := range r.sources {
settingsFromSource, err := source.Read() settingsFromSource, err := source.Read()
if err != nil { if err != nil {
return settings, fmt.Errorf("cannot read from source %T: %w", source, err) return settings, fmt.Errorf("reading from source %T: %w", source, err)
} }
settings.MergeWith(settingsFromSource) settings.MergeWith(settingsFromSource)
} }
@@ -42,7 +42,7 @@ func (r *Reader) ReadHealth() (settings settings.Health, err error) {
for _, source := range r.sources { for _, source := range r.sources {
settingsFromSource, err := source.ReadHealth() settingsFromSource, err := source.ReadHealth()
if err != nil { if err != nil {
return settings, fmt.Errorf("cannot read from source %T: %w", source, err) return settings, fmt.Errorf("reading from source %T: %w", source, err)
} }
settings.MergeWith(settingsFromSource) settings.MergeWith(settingsFromSource)
} }

View File

@@ -2,16 +2,9 @@ package firewall
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
) )
var (
ErrEnable = errors.New("failed enabling firewall")
ErrDisable = errors.New("failed disabling firewall")
ErrUserPostRules = errors.New("cannot run user post firewall rules")
)
type Enabler interface { type Enabler interface {
SetEnabled(ctx context.Context, enabled bool) (err error) SetEnabled(ctx context.Context, enabled bool) (err error)
} }
@@ -32,7 +25,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
if !enabled { if !enabled {
c.logger.Info("disabling...") c.logger.Info("disabling...")
if err = c.disable(ctx); err != nil { if err = c.disable(ctx); err != nil {
return err return fmt.Errorf("cannot disable firewall: %w", err)
} }
c.enabled = false c.enabled = false
c.logger.Info("disabled successfully") c.logger.Info("disabled successfully")
@@ -42,7 +35,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
c.logger.Info("enabling...") c.logger.Info("enabling...")
if err := c.enable(ctx); err != nil { if err := c.enable(ctx); err != nil {
return fmt.Errorf("%w: %s", ErrEnable, err) return fmt.Errorf("cannot enable firewall: %w", err)
} }
c.enabled = true c.enabled = true
c.logger.Info("enabled successfully") c.logger.Info("enabled successfully")
@@ -52,13 +45,13 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
func (c *Config) disable(ctx context.Context) (err error) { func (c *Config) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil { if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err) return fmt.Errorf("cannot clear all rules: %w", err)
} }
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil { if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err) return fmt.Errorf("cannot set ipv4 policies: %w", err)
} }
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil { if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("cannot disable firewall: %w", err) return fmt.Errorf("cannot set ipv6 policies: %w", err)
} }
return nil return nil
} }
@@ -76,12 +69,12 @@ func (c *Config) fallbackToDisabled(ctx context.Context) {
func (c *Config) enable(ctx context.Context) (err error) { func (c *Config) enable(ctx context.Context) (err error) {
touched := false touched := false
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil { if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
touched = true touched = true
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil { if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
const remove = false const remove = false
@@ -94,33 +87,33 @@ func (c *Config) enable(ctx context.Context) (err error) {
// Loopback traffic // Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil { if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil { if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil { if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
if c.vpnConnection.IP != nil { if c.vpnConnection.IP != nil {
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil { if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
} }
for _, network := range c.localNetworks { for _, network := range c.localNetworks {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil { if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
} }
for _, subnet := range c.outboundSubnets { for _, subnet := range c.outboundSubnets {
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil { if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
} }
@@ -128,18 +121,18 @@ func (c *Config) enable(ctx context.Context) (err error) {
// to reach Gluetun. // to reach Gluetun.
for _, network := range c.localNetworks { for _, network := range c.localNetworks {
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil { if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
} }
for port, intf := range c.allowedInputPorts { for port, intf := range c.allowedInputPorts {
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return err
} }
} }
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil { if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("%w: %s", ErrUserPostRules, err) return fmt.Errorf("cannot run user defined post firewall rules: %w", err)
} }
return nil return nil

View File

@@ -11,7 +11,6 @@ import (
) )
var ( var (
ErrIP6Tables = errors.New("failed ip6tables command")
ErrIP6NotSupported = errors.New("ip6tables not supported") ErrIP6NotSupported = errors.New("ip6tables not supported")
) )
@@ -44,18 +43,18 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
flags := strings.Fields(instruction) flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, "ip6tables", flags...) cmd := exec.CommandContext(ctx, "ip6tables", flags...)
if output, err := c.runner.Run(cmd); err != nil { if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("%w: \"ip6tables %s\": %s: %s", ErrIP6Tables, instruction, output, err) return fmt.Errorf("command failed: \"ip6tables %s\": %s: %w", instruction, output, err)
} }
return nil return nil
} }
var errPolicyNotValid = errors.New("policy is not valid") var ErrPolicyNotValid = errors.New("policy is not valid")
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error { func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
default: default:
return fmt.Errorf("%w: %s", errPolicyNotValid, policy) return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
} }
return c.runIP6tablesInstructions(ctx, []string{ return c.runIP6tablesInstructions(ctx, []string{
"--policy INPUT " + policy, "--policy INPUT " + policy,

View File

@@ -16,10 +16,7 @@ import (
var ( var (
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short") ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
ErrIPTables = errors.New("failed iptables command")
ErrPolicyUnknown = errors.New("unknown policy") ErrPolicyUnknown = errors.New("unknown policy")
ErrClearRules = errors.New("cannot clear all rules")
ErrSetIPtablesPolicies = errors.New("cannot set iptables policies")
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it") ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
) )
@@ -79,33 +76,30 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
flags := strings.Fields(instruction) flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, "iptables", flags...) cmd := exec.CommandContext(ctx, "iptables", flags...)
if output, err := c.runner.Run(cmd); err != nil { if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("%w \"iptables %s\": %s: %s", ErrIPTables, instruction, output, err) return fmt.Errorf("command failed: \"iptables %s\": %s: %w", instruction, output, err)
} }
return nil return nil
} }
func (c *Config) clearAllRules(ctx context.Context) error { func (c *Config) clearAllRules(ctx context.Context) error {
if err := c.runMixedIptablesInstructions(ctx, []string{ return c.runMixedIptablesInstructions(ctx, []string{
"--flush", // flush all chains "--flush", // flush all chains
"--delete-chain", // delete all chains "--delete-chain", // delete all chains
}); err != nil { })
return fmt.Errorf("%w: %s", ErrClearRules, err.Error())
}
return nil
} }
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error { func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
switch policy { switch policy {
case "ACCEPT", "DROP": case "ACCEPT", "DROP":
default: default:
return fmt.Errorf("%w: %s: %s", ErrSetIPtablesPolicies, ErrPolicyUnknown, policy) return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
} }
if err := c.runIptablesInstructions(ctx, []string{ if err := c.runIptablesInstructions(ctx, []string{
"--policy INPUT " + policy, "--policy INPUT " + policy,
"--policy OUTPUT " + policy, "--policy OUTPUT " + policy,
"--policy FORWARD " + policy, "--policy FORWARD " + policy,
}); err != nil { }); err != nil {
return fmt.Errorf("%w: %s", ErrSetIPtablesPolicies, err) return err
} }
return nil return nil
} }

View File

@@ -23,7 +23,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
return nil return nil
} }
c.logger.Info("setting allowed subnets through firewall...") c.logger.Info("setting allowed subnets...")
subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets) subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets)
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 { if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
@@ -32,7 +32,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
c.removeOutboundSubnets(ctx, subnetsToRemove) c.removeOutboundSubnets(ctx, subnetsToRemove)
if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil { if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil {
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) return fmt.Errorf("cannot set allowed outbound subnets: %w", err)
} }
return nil return nil
@@ -42,7 +42,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet)
const remove = true const remove = true
for _, subNet := range subnets { for _, subNet := range subnets {
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil { if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil {
c.logger.Error("cannot remove outdated outbound subnet through firewall: " + err.Error()) c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
continue continue
} }
c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet) c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet)
@@ -53,7 +53,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) er
const remove = false const remove = false
for _, subnet := range subnets { for _, subnet := range subnets {
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil { if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err) return err
} }
c.outboundSubnets = append(c.outboundSubnets, subnet) c.outboundSubnets = append(c.outboundSubnets, subnet)
} }

View File

@@ -33,13 +33,13 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
} }
const remove = true const remove = true
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil { if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
return fmt.Errorf("cannot remove old allowed port %d through interface %s: %w", port, existingIntf, err) return fmt.Errorf("cannot remove old allowed port %d: %w", port, err)
} }
} }
const remove = false const remove = false
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("cannot set allowed port %d through interface %s: %w", port, intf, err) return fmt.Errorf("cannot allow input to port %d: %w", port, err)
} }
c.allowedInputPorts[port] = intf c.allowedInputPorts[port] = intf
@@ -60,7 +60,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
return nil return nil
} }
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " through firewall...") c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " ...")
intf, ok := c.allowedInputPorts[port] intf, ok := c.allowedInputPorts[port]
if !ok { if !ok {
@@ -69,7 +69,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
const remove = true const remove = true
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("cannot remove allowed port %d through interface %s: %w", port, intf, err) return fmt.Errorf("cannot remove allowed port %d: %w", port, err)
} }
delete(c.allowedInputPorts, port) delete(c.allowedInputPorts, port)

View File

@@ -23,7 +23,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
return nil return nil
} }
c.logger.Info("setting VPN connection through firewall...") c.logger.Info("allowing VPN connection...")
if c.vpnConnection.Equal(connection) { if c.vpnConnection.Equal(connection) {
return nil return nil
@@ -32,14 +32,14 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove := true remove := true
if c.vpnConnection.IP != nil { if c.vpnConnection.IP != nil {
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection through firewall: " + err.Error()) c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
} }
} }
c.vpnConnection = models.Connection{} c.vpnConnection = models.Connection{}
if c.vpnIntf != "" { if c.vpnIntf != "" {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil { if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error()) c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
} }
} }
c.vpnIntf = "" c.vpnIntf = ""
@@ -47,7 +47,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
remove = false remove = false
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil { if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
return fmt.Errorf("cannot set VPN connection through firewall: %w", err) return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err)
} }
c.vpnConnection = connection c.vpnConnection = connection

View File

@@ -45,5 +45,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
if err != nil { if err != nil {
return err return err
} }
return fmt.Errorf("%w: %s: %s", ErrHTTPStatusNotOK, response.Status, string(b)) return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK,
response.StatusCode, response.Status, string(b))
} }

View File

@@ -17,12 +17,12 @@ func (e *Extractor) Data(filepath string) (lines []string,
connection models.Connection, err error) { connection models.Connection, err error) {
lines, err = readCustomConfigLines(filepath) lines, err = readCustomConfigLines(filepath)
if err != nil { if err != nil {
return nil, connection, fmt.Errorf("%w: %s", ErrRead, err) return nil, connection, fmt.Errorf("cannot read configuration file: %w", err)
} }
connection, err = extractDataFromLines(lines) connection, err = extractDataFromLines(lines)
if err != nil { if err != nil {
return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err) return nil, connection, fmt.Errorf("cannot extract connection from file: %w", err)
} }
return lines, connection, nil return lines, connection, nil

View File

@@ -48,25 +48,20 @@ func extractDataFromLines(lines []string) (
return connection, nil return connection, nil
} }
var (
errExtractProto = errors.New("failed extracting protocol from proto line")
errExtractRemote = errors.New("failed extracting from remote line")
)
func extractDataFromLine(line string) ( func extractDataFromLine(line string) (
ip net.IP, port uint16, protocol string, err error) { ip net.IP, port uint16, protocol string, err error) {
switch { switch {
case strings.HasPrefix(line, "proto "): case strings.HasPrefix(line, "proto "):
protocol, err = extractProto(line) protocol, err = extractProto(line)
if err != nil { if err != nil {
return nil, 0, "", fmt.Errorf("%w: %s", errExtractProto, err) return nil, 0, "", fmt.Errorf("failed extracting protocol from proto line: %w", err)
} }
return nil, 0, protocol, nil return nil, 0, protocol, nil
case strings.HasPrefix(line, "remote "): case strings.HasPrefix(line, "remote "):
ip, port, protocol, err = extractRemote(line) ip, port, protocol, err = extractRemote(line)
if err != nil { if err != nil {
return nil, 0, "", fmt.Errorf("%w: %s", errExtractRemote, err) return nil, 0, "", fmt.Errorf("failed extracting from remote line: %w", err)
} }
return ip, port, protocol, nil return ip, port, protocol, nil
} }
@@ -122,7 +117,7 @@ func extractRemote(line string) (ip net.IP, port uint16,
if err != nil { if err != nil {
return nil, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line) return nil, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line)
} else if portInt < 1 || portInt > 65535 { } else if portInt < 1 || portInt > 65535 {
return nil, 0, "", fmt.Errorf("%w: not between 1 and 65535: %d", errPortNotValid, portInt) return nil, 0, "", fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
} }
port = uint16(portInt) port = uint16(portInt)
} }

View File

@@ -98,7 +98,7 @@ func Test_extractDataFromLine(t *testing.T) {
}, },
"extract proto error": { "extract proto error": {
line: "proto bad", line: "proto bad",
isErr: errExtractProto, isErr: errProtocolNotSupported,
}, },
"extract proto success": { "extract proto success": {
line: "proto tcp", line: "proto tcp",
@@ -106,7 +106,7 @@ func Test_extractDataFromLine(t *testing.T) {
}, },
"extract remote error": { "extract remote error": {
line: "remote bad", line: "remote bad",
isErr: errExtractRemote, isErr: errHostNotIP,
}, },
"extract remote success": { "extract remote success": {
line: "remote 1.2.3.4 1194 udp", line: "remote 1.2.3.4 1194 udp",
@@ -213,15 +213,15 @@ func Test_extractRemote(t *testing.T) {
}, },
"port is zero": { "port is zero": {
line: "remote 1.2.3.4 0", line: "remote 1.2.3.4 0",
err: errors.New("port is not valid: not between 1 and 65535: 0"), err: errors.New("port is not valid: 0 must be between 1 and 65535"),
}, },
"port is minus one": { "port is minus one": {
line: "remote 1.2.3.4 -1", line: "remote 1.2.3.4 -1",
err: errors.New("port is not valid: not between 1 and 65535: -1"), err: errors.New("port is not valid: -1 must be between 1 and 65535"),
}, },
"port is over 65535": { "port is over 65535": {
line: "remote 1.2.3.4 65536", line: "remote 1.2.3.4 65536",
err: errors.New("port is not valid: not between 1 and 65535: 65536"), err: errors.New("port is not valid: 65536 must be between 1 and 65535"),
}, },
"IP host and port": { "IP host and port": {
line: "remote 1.2.3.4 8000", line: "remote 1.2.3.4 8000",

View File

@@ -7,7 +7,7 @@ import (
func ExtractCert(b []byte) (certData string, err error) { func ExtractCert(b []byte) (certData string, err error) {
certData, err = extractPEM(b, "CERTIFICATE") certData, err = extractPEM(b, "CERTIFICATE")
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s", ErrExtractPEM, err) return "", fmt.Errorf("cannot extract PEM data: %w", err)
} }
return certData, nil return certData, nil

View File

@@ -1,7 +0,0 @@
package parse
import "errors"
var (
ErrExtractPEM = errors.New("cannot extract PEM data")
)

View File

@@ -7,7 +7,7 @@ import (
func ExtractPrivateKey(b []byte) (keyData string, err error) { func ExtractPrivateKey(b []byte) (keyData string, err error) {
keyData, err = extractPEM(b, "PRIVATE KEY") keyData, err = extractPEM(b, "PRIVATE KEY")
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s", ErrExtractPEM, err) return "", fmt.Errorf("cannot extract PEM data: %w", err)
} }
return keyData, nil return keyData, nil

View File

@@ -27,6 +27,6 @@ func (l *Loop) firewallAllowPort(ctx context.Context) {
startData := l.state.GetStartData() startData := l.state.GetStartData()
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface) err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
if err != nil { if err != nil {
l.logger.Error("cannot allow port through firewall: " + err.Error()) l.logger.Error("cannot allow port: " + err.Error())
} }
} }

View File

@@ -13,7 +13,6 @@ import (
var ( var (
ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider") ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
ErrExtractConnection = errors.New("cannot extract connection")
) )
// GetConnection gets the connection from the OpenVPN configuration file. // GetConnection gets the connection from the OpenVPN configuration file.
@@ -34,7 +33,7 @@ func getOpenVPNConnection(extractor extract.Interface,
connection models.Connection, err error) { connection models.Connection, err error) {
_, connection, err = extractor.Data(*selection.OpenVPN.ConfFile) _, connection, err = extractor.Data(*selection.OpenVPN.ConfFile)
if err != nil { if err != nil {
return connection, fmt.Errorf("%w: %s", ErrExtractConnection, err) return connection, fmt.Errorf("cannot extract connection: %w", err)
} }
connection.Port = getPort(connection.Port, selection) connection.Port = getPort(connection.Port, selection)

View File

@@ -18,7 +18,7 @@ func (p *Provider) BuildConf(connection models.Connection,
settings settings.OpenVPN) (lines []string, err error) { settings settings.OpenVPN) (lines []string, err error) {
lines, _, err = p.extractor.Data(*settings.ConfFile) lines, _, err = p.extractor.Data(*settings.ConfFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrExtractData, err) return nil, fmt.Errorf("failed extracting information from custom configuration file: %w", err)
} }
lines = modifyConfig(lines, connection, settings) lines = modifyConfig(lines, connection, settings)

View File

@@ -1,23 +1,16 @@
package ipvanish package ipvanish
import ( import (
"errors"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
func (i *Ipvanish) GetConnection(selection settings.ServerSelection) ( func (i *Ipvanish) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
const port = 443 const port = 443
const protocol = constants.UDP const protocol = constants.UDP
if *selection.OpenVPN.TCP {
return connection, ErrProtocolUnsupported
}
servers, err := i.filterServers(selection) servers, err := i.filterServers(selection)
if err != nil { if err != nil {

View File

@@ -1,24 +1,16 @@
package privado package privado
import ( import (
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
func (p *Privado) GetConnection(selection settings.ServerSelection) ( func (p *Privado) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
const port = 1194 const port = 1194
const protocol = constants.UDP const protocol = constants.UDP
if *selection.OpenVPN.TCP {
return connection, fmt.Errorf("%w: TCP for provider Privado", ErrProtocolUnsupported)
}
servers, err := p.filterServers(selection) servers, err := p.filterServers(selection)
if err != nil { if err != nil {

View File

@@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -13,18 +12,14 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
) )
var (
ErrParseCertificate = errors.New("cannot parse X509 certificate")
)
func newHTTPClient(serverName string) (client *http.Client, err error) { func newHTTPClient(serverName string) (client *http.Client, err error) {
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong) certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrParseCertificate, err) return nil, fmt.Errorf("cannot parse X509 certificate: %w", err)
} }
certificate, err := x509.ParseCertificate(certificateBytes) certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrParseCertificate, err) return nil, fmt.Errorf("cannot parse X509 certificate: %w", err)
} }
//nolint:gomnd //nolint:gomnd

View File

@@ -23,10 +23,6 @@ import (
var ( var (
ErrGatewayIPIsNil = errors.New("gateway IP address is nil") ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
ErrServerNameEmpty = errors.New("server name is empty") ErrServerNameEmpty = errors.New("server name is empty")
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
ErrBindPort = errors.New("cannot bind port")
) )
// PortForward obtains a VPN server side port forwarded from PIA. // PortForward obtains a VPN server side port forwarded from PIA.
@@ -53,12 +49,12 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
privateIPClient, err := newHTTPClient(serverName) privateIPClient, err := newHTTPClient(serverName)
if err != nil { if err != nil {
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err) return 0, fmt.Errorf("cannot create custom HTTP client: %w", err)
} }
data, err := readPIAPortForwardData(p.portForwardPath) data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil { if err != nil {
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err) return 0, fmt.Errorf("cannot read saved port forwarded data: %w", err)
} }
dataFound := data.Port > 0 dataFound := data.Port > 0
@@ -79,7 +75,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath) p.portForwardPath, p.authFilePath)
if err != nil { if err != nil {
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err) return 0, fmt.Errorf("cannot refresh port forward data: %w", err)
} }
durationToExpiration = data.Expiration.Sub(p.timeNow()) durationToExpiration = data.Expiration.Sub(p.timeNow())
} }
@@ -87,7 +83,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
// First time binding // First time binding
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
return 0, fmt.Errorf("%w: %s", ErrBindPort, err) return 0, fmt.Errorf("cannot bind port: %w", err)
} }
return data.Port, nil return data.Port, nil
@@ -101,12 +97,12 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
port uint16, gateway net.IP, serverName string) (err error) { port uint16, gateway net.IP, serverName string) (err error) {
privateIPClient, err := newHTTPClient(serverName) privateIPClient, err := newHTTPClient(serverName)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err) return fmt.Errorf("cannot create custom HTTP client: %w", err)
} }
data, err := readPIAPortForwardData(p.portForwardPath) data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err) return fmt.Errorf("cannot read saved port forwarded data: %w", err)
} }
durationToExpiration := data.Expiration.Sub(p.timeNow()) durationToExpiration := data.Expiration.Sub(p.timeNow())
@@ -128,7 +124,7 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
case <-keepAliveTimer.C: case <-keepAliveTimer.C:
err := bindPort(ctx, privateIPClient, gateway, data) err := bindPort(ctx, privateIPClient, gateway, data)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrBindPort, err) return fmt.Errorf("cannot bind port: %w", err)
} }
keepAliveTimer.Reset(keepAlivePeriod) keepAliveTimer.Reset(keepAlivePeriod)
case <-expiryTimer.C: case <-expiryTimer.C:
@@ -138,26 +134,20 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
} }
} }
var (
ErrFetchToken = errors.New("cannot fetch token")
ErrFetchPortForwarding = errors.New("cannot fetch port forwarding data")
ErrPersistPortForwarding = errors.New("cannot persist port forwarding data")
)
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client, func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) { gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, authFilePath) data.Token, err = fetchToken(ctx, client, authFilePath)
if err != nil { if err != nil {
return data, fmt.Errorf("%w: %s", ErrFetchToken, err) return data, fmt.Errorf("cannot fetch token: %w", err)
} }
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token) data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
if err != nil { if err != nil {
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err) return data, fmt.Errorf("cannot fetch port forwarding data: %w", err)
} }
if err := writePIAPortForwardData(portForwardPath, data); err != nil { if err := writePIAPortForwardData(portForwardPath, data); err != nil {
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err) return data, fmt.Errorf("cannot persist port forwarding data: %w", err)
} }
return data, nil return data, nil
@@ -242,7 +232,6 @@ func packPayload(port uint16, token string, expiration time.Time) (payload strin
} }
var ( var (
errGetCredentials = errors.New("cannot get username and password")
errEmptyToken = errors.New("token received is empty") errEmptyToken = errors.New("token received is empty")
) )
@@ -250,7 +239,7 @@ func fetchToken(ctx context.Context, client *http.Client,
authFilePath string) (token string, err error) { authFilePath string) (token string, err error) {
username, password, err := getOpenvpnCredentials(authFilePath) username, password, err := getOpenvpnCredentials(authFilePath)
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s", errGetCredentials, err) return "", fmt.Errorf("cannot get username and password: %w", err)
} }
errSubstitutions := map[string]string{ errSubstitutions := map[string]string{
@@ -284,7 +273,7 @@ func fetchToken(ctx context.Context, client *http.Client,
Token string `json:"token"` Token string `json:"token"`
} }
if err := decoder.Decode(&result); err != nil { if err := decoder.Decode(&result); err != nil {
return "", fmt.Errorf("%w: %s", ErrUnmarshalResponse, err) return "", fmt.Errorf("cannot unmarshal response: %w", err)
} }
if result.Token == "" { if result.Token == "" {
@@ -294,7 +283,6 @@ func fetchToken(ctx context.Context, client *http.Client,
} }
var ( var (
errAuthFileRead = errors.New("cannot read OpenVPN authentication file")
errAuthFileMalformed = errors.New("authentication file is malformed") errAuthFileMalformed = errors.New("authentication file is malformed")
) )
@@ -302,13 +290,13 @@ func getOpenvpnCredentials(authFilePath string) (
username, password string, err error) { username, password string, err error) {
file, err := os.Open(authFilePath) file, err := os.Open(authFilePath)
if err != nil { if err != nil {
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err) return "", "", fmt.Errorf("cannot read OpenVPN authentication file: %w", err)
} }
authData, err := io.ReadAll(file) authData, err := io.ReadAll(file)
if err != nil { if err != nil {
_ = file.Close() _ = file.Close()
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err) return "", "", fmt.Errorf("authentication file is malformed: %w", err)
} }
if err := file.Close(); err != nil { if err := file.Close(); err != nil {
@@ -325,11 +313,6 @@ func getOpenvpnCredentials(authFilePath string) (
return username, password, nil return username, password, nil
} }
var (
errGetSignaturePayload = errors.New("cannot obtain signature payload")
errUnpackPayload = errors.New("cannot unpack payload data")
)
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) ( func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (
port uint16, signature string, expiration time.Time, err error) { port uint16, signature string, expiration time.Time, err error) {
errSubstitutions := map[string]string{token: "<token>"} errSubstitutions := map[string]string{token: "<token>"}
@@ -345,13 +328,13 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil { if err != nil {
err = replaceInErr(err, errSubstitutions) err = replaceInErr(err, errSubstitutions)
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err) return 0, "", expiration, fmt.Errorf("cannot obtain signature payload: %w", err)
} }
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
err = replaceInErr(err, errSubstitutions) err = replaceInErr(err, errSubstitutions)
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err) return 0, "", expiration, fmt.Errorf("cannot obtain signature payload: %w", err)
} }
defer response.Body.Close() defer response.Body.Close()
@@ -366,7 +349,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
Signature string `json:"signature"` Signature string `json:"signature"`
} }
if err := decoder.Decode(&data); err != nil { if err := decoder.Decode(&data); err != nil {
return 0, "", expiration, fmt.Errorf("%w: %s", ErrUnmarshalResponse, err) return 0, "", expiration, fmt.Errorf("cannot unmarshal response: %w", err)
} }
if data.Status != "OK" { if data.Status != "OK" {
@@ -375,21 +358,19 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
port, _, expiration, err = unpackPayload(data.Payload) port, _, expiration, err = unpackPayload(data.Payload)
if err != nil { if err != nil {
return 0, "", expiration, fmt.Errorf("%w: %s", errUnpackPayload, err) return 0, "", expiration, fmt.Errorf("cannot unpack payload data: %w", err)
} }
return port, data.Signature, expiration, err return port, data.Signature, expiration, err
} }
var ( var (
ErrSerializePayload = errors.New("cannot serialize payload")
ErrUnmarshalResponse = errors.New("cannot unmarshal response")
ErrBadResponse = errors.New("bad response received") ErrBadResponse = errors.New("bad response received")
) )
func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) { func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
payload, err := packPayload(data.Port, data.Token, data.Expiration) payload, err := packPayload(data.Port, data.Token, data.Expiration)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrSerializePayload, err) return fmt.Errorf("cannot serialize payload: %w", err)
} }
queryParams := make(url.Values) queryParams := make(url.Values)
@@ -428,7 +409,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
Message string `json:"message"` Message string `json:"message"`
} }
if err := decoder.Decode(&responseData); err != nil { if err := decoder.Decode(&responseData); err != nil {
return fmt.Errorf("%w: from %s: %s", ErrUnmarshalResponse, url.String(), err) return fmt.Errorf("cannot unmarshal response: from %s: %w", url.String(), err)
} }
if responseData.Status != "OK" { if responseData.Status != "OK" {
@@ -464,6 +445,7 @@ func makeNOKStatusError(response *http.Response, substitutions map[string]string
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ") shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
shortenMessage = replaceInString(shortenMessage, substitutions) shortenMessage = replaceInString(shortenMessage, substitutions)
return fmt.Errorf("%w: %s: %s: response received: %s", return fmt.Errorf("%w: %s: %d %s: response received: %s",
ErrHTTPStatusCodeNotOK, url, response.Status, shortenMessage) ErrHTTPStatusCodeNotOK, url, response.StatusCode,
response.Status, shortenMessage)
} }

View File

@@ -1,23 +1,16 @@
package vpnunlimited package vpnunlimited
import ( import (
"errors"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
func (p *Provider) GetConnection(selection settings.ServerSelection) ( func (p *Provider) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
const port = 1194 const port = 1194
const protocol = constants.UDP const protocol = constants.UDP
if *selection.OpenVPN.TCP {
return connection, ErrProtocolUnsupported
}
servers, err := p.filterServers(selection) servers, err := p.filterServers(selection)
if err != nil { if err != nil {

View File

@@ -1,24 +1,16 @@
package vyprvpn package vyprvpn
import ( import (
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
) )
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
func (v *Vyprvpn) GetConnection(selection settings.ServerSelection) ( func (v *Vyprvpn) GetConnection(selection settings.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
const port = 443 const port = 443
const protocol = constants.UDP const protocol = constants.UDP
if *selection.OpenVPN.TCP {
return connection, fmt.Errorf("%w: TCP for provider VyprVPN", ErrProtocolUnsupported)
}
servers, err := v.filterServers(selection) servers, err := v.filterServers(selection)
if err != nil { if err != nil {

View File

@@ -4,5 +4,4 @@ import "errors"
var ( var (
ErrBadStatusCode = errors.New("bad HTTP status") ErrBadStatusCode = errors.New("bad HTTP status")
ErrCannotReadBody = errors.New("cannot read response body")
) )

View File

@@ -56,12 +56,13 @@ func (f *Fetch) FetchPublicIP(ctx context.Context) (ip net.IP, err error) {
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status) return nil, fmt.Errorf("%w from %s: %d %s", ErrBadStatusCode,
url, response.StatusCode, response.Status)
} }
content, err := io.ReadAll(response.Body) content, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err) return nil, fmt.Errorf("cannot ready response body: %w", err)
} }
s := strings.ReplaceAll(string(content), "\n", "") s := strings.ReplaceAll(string(content), "\n", "")

View File

@@ -38,7 +38,8 @@ func Info(ctx context.Context, client *http.Client, ip net.IP) ( //nolint:interf
case http.StatusTooManyRequests: case http.StatusTooManyRequests:
return result, fmt.Errorf("%w: %s", ErrTooManyRequests, baseURL) return result, fmt.Errorf("%w: %s", ErrTooManyRequests, baseURL)
default: default:
return result, fmt.Errorf("%w: %d", ErrBadHTTPStatus, response.StatusCode) return result, fmt.Errorf("%w: %d %s", ErrBadHTTPStatus,
response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)

View File

@@ -19,7 +19,7 @@ type DefaultRouteGetter interface {
func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("%w: %s", ErrRoutesList, err) return "", nil, fmt.Errorf("cannot list routes: %w", err)
} }
for _, route := range routes { for _, route := range routes {
if route.Dst == nil { if route.Dst == nil {
@@ -27,7 +27,7 @@ func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
linkIndex := route.LinkIndex linkIndex := route.LinkIndex
link, err := r.netLinker.LinkByIndex(linkIndex) link, err := r.netLinker.LinkByIndex(linkIndex)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) return "", nil, fmt.Errorf("cannot obtain link by index: for default route at index %d: %w", linkIndex, err)
} }
attributes := link.Attrs() attributes := link.Attrs()
defaultInterface = attributes.Name defaultInterface = attributes.Name
@@ -46,7 +46,7 @@ type DefaultIPGetter interface {
func (r *Routing) DefaultIP() (ip net.IP, err error) { func (r *Routing) DefaultIP() (ip net.IP, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) return nil, fmt.Errorf("cannot list routes: %w", err)
} }
defaultLinkName := "" defaultLinkName := ""
@@ -55,7 +55,7 @@ func (r *Routing) DefaultIP() (ip net.IP, err error) {
linkIndex := route.LinkIndex linkIndex := route.LinkIndex
link, err := r.netLinker.LinkByIndex(linkIndex) link, err := r.netLinker.LinkByIndex(linkIndex)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) return nil, fmt.Errorf("cannot find link by index: for default route at index %d: %w", linkIndex, err)
} }
defaultLinkName = link.Attrs().Name defaultLinkName = link.Attrs().Name
} }

View File

@@ -1,17 +1,9 @@
package routing package routing
import ( import (
"errors"
"fmt" "fmt"
) )
var (
ErrDefaultRoute = errors.New("cannot get default route")
ErrAddInboundFromDefault = errors.New("cannot add routes for inbound traffic from default IP")
ErrDelInboundFromDefault = errors.New("cannot remove routes for inbound traffic from default IP")
ErrSubnetsOutboundSet = errors.New("cannot set outbound subnets routes")
)
type Setuper interface { type Setuper interface {
Setup() (err error) Setup() (err error)
} }
@@ -19,7 +11,7 @@ type Setuper interface {
func (r *Routing) Setup() (err error) { func (r *Routing) Setup() (err error) {
defaultInterfaceName, defaultGateway, err := r.DefaultRoute() defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrDefaultRoute, err) return fmt.Errorf("cannot get default route: %w", err)
} }
touched := false touched := false
@@ -35,14 +27,14 @@ func (r *Routing) Setup() (err error) {
err = r.routeInboundFromDefault(defaultGateway, defaultInterfaceName) err = r.routeInboundFromDefault(defaultGateway, defaultInterfaceName)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrAddInboundFromDefault, err) return fmt.Errorf("cannot add routes for inbound traffic from default IP: %w", err)
} }
r.stateMutex.RLock() r.stateMutex.RLock()
outboundSubnets := r.outboundSubnets outboundSubnets := r.outboundSubnets
r.stateMutex.RUnlock() r.stateMutex.RUnlock()
if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil { if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil {
return fmt.Errorf("%w: %s", ErrSubnetsOutboundSet, err) return fmt.Errorf("cannot set outbound subnets routes: %w", err)
} }
return nil return nil
@@ -55,16 +47,16 @@ type TearDowner interface {
func (r *Routing) TearDown() error { func (r *Routing) TearDown() error {
defaultInterfaceName, defaultGateway, err := r.DefaultRoute() defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrDefaultRoute, err) return fmt.Errorf("cannot get default route: %w", err)
} }
err = r.unrouteInboundFromDefault(defaultGateway, defaultInterfaceName) err = r.unrouteInboundFromDefault(defaultGateway, defaultInterfaceName)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrDelInboundFromDefault, err) return fmt.Errorf("cannot remove routes for inbound traffic from default IP: %w", err)
} }
if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil { if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil {
return fmt.Errorf("%w: %s", ErrSubnetsOutboundSet, err) return fmt.Errorf("cannot set outbound subnets routes: %w", err)
} }
return nil return nil

View File

@@ -5,7 +5,5 @@ import (
) )
var ( var (
ErrLinkByIndex = errors.New("cannot obtain link by index")
ErrLinkDefaultNotFound = errors.New("default link not found") ErrLinkDefaultNotFound = errors.New("default link not found")
ErrRoutesList = errors.New("cannot list routes")
) )

View File

@@ -1,7 +1,6 @@
package routing package routing
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
@@ -13,19 +12,15 @@ const (
inboundPriority = 100 inboundPriority = 100
) )
var (
errDefaultIP = errors.New("cannot get default IP address")
)
func (r *Routing) routeInboundFromDefault(defaultGateway net.IP, func (r *Routing) routeInboundFromDefault(defaultGateway net.IP,
defaultInterface string) (err error) { defaultInterface string) (err error) {
if err := r.addRuleInboundFromDefault(inboundTable); err != nil { if err := r.addRuleInboundFromDefault(inboundTable); err != nil {
return fmt.Errorf("%w: %s", errRuleAdd, err) return fmt.Errorf("cannot add rule: %w", err)
} }
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil { if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
return fmt.Errorf("%w: %s", errRouteAdd, err) return fmt.Errorf("cannot add route: %w", err)
} }
return nil return nil
@@ -35,11 +30,11 @@ func (r *Routing) unrouteInboundFromDefault(defaultGateway net.IP,
defaultInterface string) (err error) { defaultInterface string) (err error) {
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
if err := r.deleteRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil { if err := r.deleteRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
return fmt.Errorf("%w: %s", errRouteDelete, err) return fmt.Errorf("cannot delete route: %w", err)
} }
if err := r.delRuleInboundFromDefault(inboundTable); err != nil { if err := r.delRuleInboundFromDefault(inboundTable); err != nil {
return fmt.Errorf("%w: %s", errRuleDelete, err) return fmt.Errorf("cannot delete rule: %w", err)
} }
return nil return nil
@@ -48,14 +43,14 @@ func (r *Routing) unrouteInboundFromDefault(defaultGateway net.IP,
func (r *Routing) addRuleInboundFromDefault(table int) (err error) { func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
defaultIP, err := r.DefaultIP() defaultIP, err := r.DefaultIP()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errDefaultIP, err) return fmt.Errorf("cannot find default IP: %w", err)
} }
defaultIPMasked32 := netlink.NewIPNet(defaultIP) defaultIPMasked32 := netlink.NewIPNet(defaultIP)
ruleDstNet := (*net.IPNet)(nil) ruleDstNet := (*net.IPNet)(nil)
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errRuleAdd, err) return fmt.Errorf("cannot add rule: %w", err)
} }
return nil return nil
@@ -64,14 +59,14 @@ func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
func (r *Routing) delRuleInboundFromDefault(table int) (err error) { func (r *Routing) delRuleInboundFromDefault(table int) (err error) {
defaultIP, err := r.DefaultIP() defaultIP, err := r.DefaultIP()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errDefaultIP, err) return fmt.Errorf("cannot find default IP: %w", err)
} }
defaultIPMasked32 := netlink.NewIPNet(defaultIP) defaultIPMasked32 := netlink.NewIPNet(defaultIP)
ruleDstNet := (*net.IPNet)(nil) ruleDstNet := (*net.IPNet)(nil)
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errRuleDelete, err) return fmt.Errorf("cannot delete rule: %w", err)
} }
return nil return nil

View File

@@ -13,18 +13,16 @@ func IPIsPrivate(ip net.IP) bool {
var ( var (
errInterfaceIPNotFound = errors.New("IP address not found for interface") errInterfaceIPNotFound = errors.New("IP address not found for interface")
errInterfaceListAddr = errors.New("cannot list interface addresses")
errInterfaceNotFound = errors.New("network interface not found")
) )
func (r *Routing) assignedIP(interfaceName string) (ip net.IP, err error) { func (r *Routing) assignedIP(interfaceName string) (ip net.IP, err error) {
iface, err := net.InterfaceByName(interfaceName) iface, err := net.InterfaceByName(interfaceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s: %s", errInterfaceNotFound, interfaceName, err) return nil, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
} }
addresses, err := iface.Addrs() addresses, err := iface.Addrs()
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s: %s", errInterfaceListAddr, interfaceName, err) return nil, fmt.Errorf("cannot list interface %s addresses: %w", interfaceName, err)
} }
for _, address := range addresses { for _, address := range addresses {
switch value := address.(type) { switch value := address.(type) {

View File

@@ -9,7 +9,6 @@ import (
) )
var ( var (
ErrLinkList = errors.New("cannot list links")
ErrLinkLocalNotFound = errors.New("local link not found") ErrLinkLocalNotFound = errors.New("local link not found")
ErrSubnetDefaultNotFound = errors.New("default subnet not found") ErrSubnetDefaultNotFound = errors.New("default subnet not found")
ErrSubnetLocalNotFound = errors.New("local subnet not found") ErrSubnetLocalNotFound = errors.New("local subnet not found")
@@ -28,7 +27,7 @@ type LocalSubnetGetter interface {
func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
return defaultSubnet, fmt.Errorf("%w: %s", ErrRoutesList, err) return defaultSubnet, fmt.Errorf("cannot list routes: %w", err)
} }
defaultLinkIndex := -1 defaultLinkIndex := -1
@@ -61,7 +60,7 @@ type LocalNetworksGetter interface {
func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
links, err := r.netLinker.LinkList() links, err := r.netLinker.LinkList()
if err != nil { if err != nil {
return localNetworks, fmt.Errorf("%w: %s", ErrLinkList, err) return localNetworks, fmt.Errorf("cannot list links: %w", err)
} }
localLinks := make(map[int]struct{}) localLinks := make(map[int]struct{})
@@ -81,7 +80,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4) routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4)
if err != nil { if err != nil {
return localNetworks, fmt.Errorf("%w: %s", ErrRoutesList, err) return localNetworks, fmt.Errorf("cannot list routes: %w", err)
} }
for _, route := range routes { for _, route := range routes {
@@ -98,7 +97,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
link, err := r.netLinker.LinkByIndex(route.LinkIndex) link, err := r.netLinker.LinkByIndex(route.LinkIndex)
if err != nil { if err != nil {
return localNetworks, fmt.Errorf("%w: at index %d: %s", ErrLinkByIndex, route.LinkIndex, err) return localNetworks, fmt.Errorf("cannot find link at index %d: %w", route.LinkIndex, err)
} }
localNet.InterfaceName = link.Attrs().Name localNet.InterfaceName = link.Attrs().Name

View File

@@ -1,7 +1,6 @@
package routing package routing
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
@@ -13,10 +12,6 @@ const (
outboundPriority = 99 outboundPriority = 99
) )
var (
errAddOutboundSubnet = errors.New("cannot add outbound subnet to routes")
)
type OutboundRoutesSetter interface { type OutboundRoutesSetter interface {
SetOutboundRoutes(outboundSubnets []net.IPNet) error SetOutboundRoutes(outboundSubnets []net.IPNet) error
} }
@@ -48,7 +43,7 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway) err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errAddOutboundSubnet, err) return fmt.Errorf("cannot add outbound subnet to routes: %w", err)
} }
return nil return nil
@@ -68,7 +63,7 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
err = r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) err = r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil { if err != nil {
warnings = append(warnings, warnings = append(warnings,
errRuleDelete.Error()+": for subnet "+subNet.String()+": "+err.Error()) "cannot delete rule: for subnet "+subNet.String()+": "+err.Error())
continue continue
} }
@@ -83,16 +78,14 @@ func (r *Routing) addOutboundSubnets(subnets []net.IPNet,
for i, subnet := range subnets { for i, subnet := range subnets {
err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, outboundTable) err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, outboundTable)
if err != nil { if err != nil {
return fmt.Errorf("%w: for subnet %s: %s", return fmt.Errorf("cannot add route for subnet %s: %w", subnet, err)
errRouteAdd, subnet, err)
} }
ruleSrcNet := (*net.IPNet)(nil) ruleSrcNet := (*net.IPNet)(nil)
ruleDstNet := &subnets[i] ruleDstNet := &subnets[i]
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil { if err != nil {
return fmt.Errorf("%w: for subnet %s: %s", return fmt.Errorf("cannot add rule: for subnet %s: %w", subnet, err)
errRuleAdd, subnet, err)
} }
r.outboundSubnets = append(r.outboundSubnets, subnet) r.outboundSubnets = append(r.outboundSubnets, subnet)

View File

@@ -1,7 +1,6 @@
package routing package routing
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@@ -9,12 +8,6 @@ import (
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
var (
errLinkByName = errors.New("cannot obtain link by name")
errRouteAdd = errors.New("cannot add route")
errRouteDelete = errors.New("cannot delete route")
)
func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
iface string, table int) error { iface string, table int) error {
destinationStr := destination.String() destinationStr := destination.String()
@@ -26,7 +19,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
link, err := r.netLinker.LinkByName(iface) link, err := r.netLinker.LinkByName(iface)
if err != nil { if err != nil {
return fmt.Errorf("%w: interface %s: %s", errLinkByName, iface, err) return fmt.Errorf("cannot find link for interface %s: %w", iface, err)
} }
route := netlink.Route{ route := netlink.Route{
@@ -36,8 +29,8 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
Table: table, Table: table,
} }
if err := r.netLinker.RouteReplace(&route); err != nil { if err := r.netLinker.RouteReplace(&route); err != nil {
return fmt.Errorf("%w: for subnet %s at interface %s", return fmt.Errorf("cannot replace route for subnet %s at interface %s: %w",
err, destinationStr, iface) destinationStr, iface, err)
} }
return nil return nil
@@ -54,7 +47,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP,
link, err := r.netLinker.LinkByName(iface) link, err := r.netLinker.LinkByName(iface)
if err != nil { if err != nil {
return fmt.Errorf("%w: for interface %s: %s", errLinkByName, iface, err) return fmt.Errorf("cannot find link for interface %s: %w", iface, err)
} }
route := netlink.Route{ route := netlink.Route{
@@ -64,8 +57,8 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP,
Table: table, Table: table,
} }
if err := r.netLinker.RouteDel(&route); err != nil { if err := r.netLinker.RouteDel(&route); err != nil {
return fmt.Errorf("%w: for subnet %s at interface %s", return fmt.Errorf("cannot delete route: for subnet %s at interface %s: %w",
err, destinationStr, iface) destinationStr, iface, err)
} }
return nil return nil

View File

@@ -2,19 +2,12 @@ package routing
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net" "net"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
var (
errRulesList = errors.New("cannot list rules")
errRuleAdd = errors.New("cannot add rule")
errRuleDelete = errors.New("cannot delete rule")
)
func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error { func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
const add = true const add = true
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
@@ -27,7 +20,7 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errRulesList, err) return fmt.Errorf("cannot list rules: %w", err)
} }
for i := range existingRules { for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) { if !rulesAreEqual(&existingRules[i], rule) {
@@ -37,7 +30,7 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
} }
if err := r.netLinker.RuleAdd(rule); err != nil { if err := r.netLinker.RuleAdd(rule); err != nil {
return fmt.Errorf("%w: for rule: %s", err, rule) return fmt.Errorf("cannot add rule %s: %w", rule, err)
} }
return nil return nil
} }
@@ -54,14 +47,14 @@ func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error {
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errRulesList, err) return fmt.Errorf("cannot list rules: %w", err)
} }
for i := range existingRules { for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) { if !rulesAreEqual(&existingRules[i], rule) {
continue continue
} }
if err := r.netLinker.RuleDel(rule); err != nil { if err := r.netLinker.RuleDel(rule); err != nil {
return fmt.Errorf("%w: for rule: %s", err, rule) return fmt.Errorf("cannot delete rule %s: %w", rule, err)
} }
} }
return nil return nil

View File

@@ -88,7 +88,7 @@ func Test_Routing_addIPRule(t *testing.T) {
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy, err: errDummy,
}, },
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"), err: errors.New("cannot add rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
}, },
"add rule success": { "add rule success": {
src: makeIPNet(t, 1), src: makeIPNet(t, 1),
@@ -193,7 +193,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy, err: errDummy,
}, },
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"), err: errors.New("cannot delete rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
}, },
"rule deleted": { "rule deleted": {
src: makeIPNet(t, 1), src: makeIPNet(t, 1),

View File

@@ -21,7 +21,7 @@ type VPNDestinationIPGetter interface {
func (r *Routing) VPNDestinationIP() (ip net.IP, err error) { func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) return nil, fmt.Errorf("cannot list routes: %w", err)
} }
defaultLinkIndex := -1 defaultLinkIndex := -1
@@ -53,12 +53,12 @@ type VPNLocalGatewayIPGetter interface {
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) { func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) return nil, fmt.Errorf("cannot list routes: %w", err)
} }
for _, route := range routes { for _, route := range routes {
link, err := r.netLinker.LinkByIndex(route.LinkIndex) link, err := r.netLinker.LinkByIndex(route.LinkIndex)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err) return nil, fmt.Errorf("cannot find link at index %d: %w", route.LinkIndex, err)
} }
interfaceName := link.Attrs().Name interfaceName := link.Attrs().Name
if interfaceName == vpnIntf && if interfaceName == vpnIntf &&

View File

@@ -35,8 +35,6 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
} }
var ( var (
errDecodeVersions = errors.New("cannot decode versions")
errDecodeServers = errors.New("cannot decode servers")
errDecodeProvider = errors.New("cannot decode servers for provider") errDecodeProvider = errors.New("cannot decode servers for provider")
) )
@@ -44,12 +42,12 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
servers models.AllServers, err error) { servers models.AllServers, err error) {
var versions allVersions var versions allVersions
if err := json.Unmarshal(b, &versions); err != nil { if err := json.Unmarshal(b, &versions); err != nil {
return servers, fmt.Errorf("%w: %s", errDecodeVersions, err) return servers, fmt.Errorf("cannot decode versions: %w", err)
} }
var rawMessages allJSONRawMessages var rawMessages allJSONRawMessages
if err := json.Unmarshal(b, &rawMessages); err != nil { if err := json.Unmarshal(b, &rawMessages); err != nil {
return servers, fmt.Errorf("%w: %s", errDecodeServers, err) return servers, fmt.Errorf("cannot decode servers: %w", err)
} }
// TODO simplify with generics in Go 1.18 // TODO simplify with generics in Go 1.18

View File

@@ -1,18 +1,12 @@
package storage package storage
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
var (
ErrCannotReadFile = errors.New("cannot read servers from file")
ErrCannotWriteFile = errors.New("cannot write servers to file")
)
func countServers(allServers models.AllServers) int { func countServers(allServers models.AllServers) int {
return len(allServers.Cyberghost.Servers) + return len(allServers.Cyberghost.Servers) +
len(allServers.Expressvpn.Servers) + len(allServers.Expressvpn.Servers) +
@@ -39,7 +33,7 @@ func countServers(allServers models.AllServers) int {
func (s *Storage) SyncServers() (err error) { func (s *Storage) SyncServers() (err error) {
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers) serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrCannotReadFile, err) return fmt.Errorf("cannot read servers from file: %w", err)
} }
hardcodedCount := countServers(s.hardcodedServers) hardcodedCount := countServers(s.hardcodedServers)
@@ -64,7 +58,7 @@ func (s *Storage) SyncServers() (err error) {
} }
if err := flushToFile(s.filepath, s.mergedServers); err != nil { if err := flushToFile(s.filepath, s.mergedServers); err != nil {
return fmt.Errorf("%w: %s", ErrCannotWriteFile, err) return fmt.Errorf("cannot write servers to file: %w", err)
} }
return nil return nil
} }

View File

@@ -12,24 +12,21 @@ type Checker interface {
} }
var ( var (
ErrTUNNotAvailable = errors.New("TUN device is not available")
ErrTUNStat = errors.New("cannot stat TUN file")
ErrTUNInfo = errors.New("cannot get syscall stat info of TUN file") ErrTUNInfo = errors.New("cannot get syscall stat info of TUN file")
ErrTUNBadRdev = errors.New("TUN file has an unexpected rdev") ErrTUNBadRdev = errors.New("TUN file has an unexpected rdev")
ErrTUNClose = errors.New("cannot close TUN device")
) )
// Check checks the tunnel device specified by path is present and accessible. // Check checks the tunnel device specified by path is present and accessible.
func (t *Tun) Check(path string) error { func (t *Tun) Check(path string) error {
f, err := os.OpenFile(path, os.O_RDWR, 0) f, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrTUNNotAvailable, err) return fmt.Errorf("TUN device is not available: %w", err)
} }
defer f.Close() defer f.Close()
info, err := f.Stat() info, err := f.Stat()
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrTUNStat, err) return fmt.Errorf("cannot stat TUN file: %w", err)
} }
sys, ok := info.Sys().(*syscall.Stat_t) sys, ok := info.Sys().(*syscall.Stat_t)
@@ -44,7 +41,7 @@ func (t *Tun) Check(path string) error {
} }
if err := f.Close(); err != nil { if err := f.Close(); err != nil {
return fmt.Errorf("%w: %s", ErrTUNClose, err) return fmt.Errorf("cannot close TUN device: %w", err)
} }
return nil return nil

View File

@@ -1,7 +1,6 @@
package tun package tun
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -13,12 +12,6 @@ type Creator interface {
Create(path string) error Create(path string) error
} }
var (
ErrMknod = errors.New("cannot create TUN device file node")
ErrUnixOpen = errors.New("cannot Unix Open TUN device file")
ErrSetNonBlock = errors.New("cannot set non block to TUN device file descriptor")
)
// Create creates a TUN device at the path specified. // Create creates a TUN device at the path specified.
func (t *Tun) Create(path string) error { func (t *Tun) Create(path string) error {
parentDir := filepath.Dir(path) parentDir := filepath.Dir(path)
@@ -33,18 +26,18 @@ func (t *Tun) Create(path string) error {
dev := unix.Mkdev(major, minor) dev := unix.Mkdev(major, minor)
err := t.mknod(path, unix.S_IFCHR, int(dev)) err := t.mknod(path, unix.S_IFCHR, int(dev))
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrMknod, err) return fmt.Errorf("cannot create TUN device file node: %w", err)
} }
fd, err := unix.Open(path, 0, 0) fd, err := unix.Open(path, 0, 0)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrUnixOpen, err) return fmt.Errorf("cannot Unix Open TUN device file: %w", err)
} }
const nonBlocking = true const nonBlocking = true
err = unix.SetNonblock(fd, nonBlocking) err = unix.SetNonblock(fd, nonBlocking)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", ErrSetNonBlock, err) return fmt.Errorf("cannot set non block to TUN device file descriptor: %w", err)
} }
return nil return nil

View File

@@ -9,11 +9,7 @@ import (
) )
var ( var (
errBuildRequest = errors.New("cannot build HTTP request")
errDoRequest = errors.New("failed doing HTTP request")
errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
errUnmarshalResponseBody = errors.New("failed unmarshaling response body")
errCloseBody = errors.New("failed closing HTTP body")
) )
type apiData struct { type apiData struct {
@@ -40,12 +36,12 @@ func fetchAPI(ctx context.Context, client *http.Client) (
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return data, fmt.Errorf("%w: %s", errBuildRequest, err) return data, err
} }
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
return data, fmt.Errorf("%w: %s", errDoRequest, err) return data, err
} }
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
@@ -57,11 +53,11 @@ func fetchAPI(ctx context.Context, client *http.Client) (
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&data); err != nil { if err := decoder.Decode(&data); err != nil {
_ = response.Body.Close() _ = response.Body.Close()
return data, fmt.Errorf("%w: %s", errUnmarshalResponseBody, err) return data, fmt.Errorf("failed unmarshaling response body: %w", err)
} }
if err := response.Body.Close(); err != nil { if err := response.Body.Close(); err != nil {
return data, fmt.Errorf("%w: %s", errCloseBody, err) return data, fmt.Errorf("cannot close response body: %w", err)
} }
return data, nil return data, nil

View File

@@ -14,7 +14,6 @@ import (
) )
var ( var (
ErrFetchAPI = errors.New("failed fetching API")
ErrNotEnoughServers = errors.New("not enough servers found") ErrNotEnoughServers = errors.New("not enough servers found")
) )
@@ -23,7 +22,7 @@ func GetServers(ctx context.Context, client *http.Client,
servers []models.IvpnServer, warnings []string, err error) { servers []models.IvpnServer, warnings []string, err error) {
data, err := fetchAPI(ctx, client) data, err := fetchAPI(ctx, client)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrFetchAPI, err) return nil, nil, fmt.Errorf("failed fetching API: %w", err)
} }
hosts := make([]string, 0, len(data.Servers)) hosts := make([]string, 0, len(data.Servers))

View File

@@ -41,7 +41,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %s", ErrHTTPStatusCodeNotOK, response.Status) return nil, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)

View File

@@ -10,7 +10,6 @@ import (
var ( var (
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
) )
type serverData struct { type serverData struct {
@@ -44,7 +43,7 @@ func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&data); err != nil { if err := decoder.Decode(&data); err != nil {
return nil, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err) return nil, fmt.Errorf("failed unmarshaling response body: %w", err)
} }
if err := response.Body.Close(); err != nil { if err := response.Body.Close(); err != nil {

View File

@@ -51,7 +51,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return data, fmt.Errorf("%w: %s", ErrHTTPStatusCodeNotOK, response.Status) return data, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
response.StatusCode, response.Status)
} }
b, err := io.ReadAll(response.Body) b, err := io.ReadAll(response.Body)

View File

@@ -44,8 +44,7 @@ func parseFilename(fileName string) (
parts := strings.Split(s, "-") parts := strings.Split(s, "-")
const minParts = 2 const minParts = 2
if len(parts) < minParts { if len(parts) < minParts {
return "", "", fmt.Errorf("%w: %s", return "", "", fmt.Errorf("%w: %s", errNotEnoughParts, fileName)
errNotEnoughParts, fileName)
} }
countryCode, city = parts[0], parts[1] countryCode, city = parts[0], parts[1]

View File

@@ -11,7 +11,6 @@ import (
var ( var (
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
) )
type apiData struct { type apiData struct {
@@ -49,12 +48,13 @@ func fetchAPI(ctx context.Context, client *http.Client) (
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return data, fmt.Errorf("%w: %s", ErrHTTPStatusCodeNotOK, response.Status) return data, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&data); err != nil { if err := decoder.Decode(&data); err != nil {
return data, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err) return data, fmt.Errorf("failed unmarshaling response body: %w", err)
} }
if err := response.Body.Close(); err != nil { if err := response.Body.Close(); err != nil {

View File

@@ -34,7 +34,6 @@ func addServersFromAPI(ctx context.Context, client *http.Client,
var ( var (
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
) )
type serverData struct { type serverData struct {
@@ -66,7 +65,7 @@ func fetchAPI(ctx context.Context, client *http.Client) (
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&servers); err != nil { if err := decoder.Decode(&servers); err != nil {
return nil, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err) return nil, fmt.Errorf("failed unmarshaling response body: %w", err)
} }
if err := response.Body.Close(); err != nil { if err := response.Body.Close(); err != nil {

View File

@@ -14,8 +14,6 @@ import (
) )
var ( var (
ErrGetZip = errors.New("cannot get OpenVPN ZIP file")
ErrGetAPI = errors.New("cannot fetch server information from API")
ErrNotEnoughServers = errors.New("not enough servers found") ErrNotEnoughServers = errors.New("not enough servers found")
) )
@@ -26,12 +24,12 @@ func GetServers(ctx context.Context, unzipper unzip.Unzipper,
err = addServersFromAPI(ctx, client, hts) err = addServersFromAPI(ctx, client, hts)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrGetAPI, err) return nil, nil, fmt.Errorf("cannot fetch server information from API: %w", err)
} }
warnings, err = addOpenVPNServersFromZip(ctx, unzipper, hts) warnings, err = addOpenVPNServersFromZip(ctx, unzipper, hts)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrGetZip, err) return nil, nil, fmt.Errorf("cannot get OpenVPN ZIP file: %w", err)
} }
getRemainingServers(hts) getRemainingServers(hts)

View File

@@ -13,7 +13,6 @@ import (
var ( var (
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK") ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
) )
type apiData struct { type apiData struct {
@@ -63,7 +62,7 @@ func fetchAPI(ctx context.Context, client *http.Client) (
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&data); err != nil { if err := decoder.Decode(&data); err != nil {
return data, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err) return data, fmt.Errorf("failed unmarshaling response body: %w", err)
} }
return data, nil return data, nil

View File

@@ -66,7 +66,6 @@ func (r *repeat) Resolve(ctx context.Context, host string, settings RepeatSettin
var ( var (
ErrMaxNoNew = errors.New("reached the maximum number of no new update") ErrMaxNoNew = errors.New("reached the maximum number of no new update")
ErrMaxFails = errors.New("reached the maximum number of consecutive failures") ErrMaxFails = errors.New("reached the maximum number of consecutive failures")
ErrTimeout = errors.New("reached the timeout")
) )
func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string, func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string,
@@ -120,7 +119,7 @@ func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string,
return noNewCounter, failCounter, err return noNewCounter, failCounter, err
} }
return noNewCounter, failCounter, return noNewCounter, failCounter,
fmt.Errorf("%w: %s", ErrTimeout, timedCtx.Err()) fmt.Errorf("reached the timeout: %w", timedCtx.Err())
} }
} }

View File

@@ -26,7 +26,8 @@ func (u *unzipper) FetchAndExtract(ctx context.Context, url string) (
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %s for %s", ErrHTTPStatusCodeNotOK, response.Status, url) return nil, fmt.Errorf("%w: %s: %d %s", ErrHTTPStatusCodeNotOK,
url, response.StatusCode, response.Status)
} }
b, err := io.ReadAll(response.Body) b, err := io.ReadAll(response.Body)

View File

@@ -41,7 +41,8 @@ func getGithubReleases(ctx context.Context, client *http.Client) (releases []git
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %s", errHTTPStatusCode, response.Status) return nil, fmt.Errorf("%w: %d %s", errHTTPStatusCode,
response.StatusCode, response.Status)
} }
decoder := json.NewDecoder(response.Body) decoder := json.NewDecoder(response.Body)

View File

@@ -2,7 +2,6 @@ package vpn
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
@@ -12,14 +11,6 @@ import (
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
) )
var (
errServerConn = errors.New("failed finding a valid server connection")
errBuildConfig = errors.New("failed building configuration")
errWriteConfig = errors.New("failed writing configuration to file")
errWriteAuth = errors.New("failed writing auth to file")
errFirewall = errors.New("failed allowing VPN connection through firewall")
)
// setupOpenVPN sets OpenVPN up using the configurators and settings given. // setupOpenVPN sets OpenVPN up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails. // It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
@@ -28,27 +19,27 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
runner vpnRunner, serverName string, err error) { runner vpnRunner, serverName string, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection) connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errServerConn, err) return nil, "", fmt.Errorf("failed finding a valid server connection: %w", err)
} }
lines, err := providerConf.BuildConf(connection, settings.OpenVPN) lines, err := providerConf.BuildConf(connection, settings.OpenVPN)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err) return nil, "", fmt.Errorf("failed building configuration: %w", err)
} }
if err := openvpnConf.WriteConfig(lines); err != nil { if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, "", fmt.Errorf("%w: %s", errWriteConfig, err) return nil, "", fmt.Errorf("failed writing configuration to file: %w", err)
} }
if settings.OpenVPN.User != "" { if settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(settings.OpenVPN.User, settings.OpenVPN.Password) err := openvpnConf.WriteAuthFile(settings.OpenVPN.User, settings.OpenVPN.Password)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errWriteAuth, err) return nil, "", fmt.Errorf("failed writing auth to file: %w", err)
} }
} }
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil { if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
return nil, "", fmt.Errorf("%w: %s", errFirewall, err) return nil, "", fmt.Errorf("failed allowing VPN connection through firewall: %w", err)
} }
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger) runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)

View File

@@ -2,18 +2,12 @@ package vpn
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
"github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/portforward"
) )
var (
errObtainVPNLocalGateway = errors.New("cannot obtain VPN local gateway IP")
errStartPortForwarding = errors.New("cannot start port forwarding")
)
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) { func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
if !data.portForwarding { if !data.portForwarding {
return nil return nil
@@ -22,7 +16,7 @@ func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err
// only used for PIA for now // only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf) gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
if err != nil { if err != nil {
return fmt.Errorf("%w: for interface %s: %s", errObtainVPNLocalGateway, data.vpnIntf, err) return fmt.Errorf("cannot obtain VPN local gateway IP for interface %s: %w", data.vpnIntf, err)
} }
l.logger.Info("VPN gateway IP address: " + gateway.String()) l.logger.Info("VPN gateway IP address: " + gateway.String())
@@ -34,7 +28,7 @@ func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err
} }
_, err = l.portForward.Start(ctx, pfData) _, err = l.portForward.Start(ctx, pfData)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errStartPortForwarding, err) return fmt.Errorf("cannot start port forwarding: %w", err)
} }
return nil return nil

View File

@@ -2,7 +2,6 @@ package vpn
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
@@ -13,11 +12,6 @@ import (
"github.com/qdm12/gluetun/internal/wireguard" "github.com/qdm12/gluetun/internal/wireguard"
) )
var (
errGetServer = errors.New("failed finding a VPN server")
errCreateWireguard = errors.New("failed creating Wireguard")
)
// setupWireguard sets Wireguard up using the configurators and settings given. // setupWireguard sets Wireguard up using the configurators and settings given.
// It returns a serverName for port forwarding (PIA) and an error if it fails. // It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupWireguard(ctx context.Context, netlinker netlink.NetLinker, func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
@@ -26,7 +20,7 @@ func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
wireguarder wireguard.Wireguarder, serverName string, err error) { wireguarder wireguard.Wireguarder, serverName string, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection) connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errGetServer, err) return nil, "", fmt.Errorf("failed finding a VPN server: %w", err)
} }
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard) wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard)
@@ -37,12 +31,12 @@ func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errCreateWireguard, err) return nil, "", fmt.Errorf("failed creating Wireguard: %w", err)
} }
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("%w: %s", errFirewall, err) return nil, "", fmt.Errorf("failed setting firewall: %w", err)
} }
return wireguarder, connection.Hostname, nil return wireguarder, connection.Hostname, nil

View File

@@ -1,7 +1,6 @@
package wireguard package wireguard
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
@@ -9,20 +8,15 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var (
errMakeConfig = errors.New("cannot make device configuration")
errConfigureDevice = errors.New("cannot configure device")
)
func configureDevice(client *wgctrl.Client, settings Settings) (err error) { func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
deviceConfig, err := makeDeviceConfig(settings) deviceConfig, err := makeDeviceConfig(settings)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errMakeConfig, err) return fmt.Errorf("cannot make device configuration: %w", err)
} }
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig) err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
if err != nil { if err != nil {
return fmt.Errorf("%w: %s", errConfigureDevice, err) return fmt.Errorf("cannot configure device: %w", err)
} }
return nil return nil

View File

@@ -1,27 +1,21 @@
package wireguard package wireguard
import ( import (
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
var (
errLinkList = errors.New("cannot list links")
errRouteList = errors.New("cannot list routes")
)
func (w *Wireguard) isIPv6Supported() (supported bool, err error) { func (w *Wireguard) isIPv6Supported() (supported bool, err error) {
links, err := w.netlink.LinkList() links, err := w.netlink.LinkList()
if err != nil { if err != nil {
return false, fmt.Errorf("%w: %s", errLinkList, err) return false, fmt.Errorf("cannot list links: %w", err)
} }
for _, link := range links { for _, link := range links {
routes, err := w.netlink.RouteList(link, netlink.FAMILY_V6) routes, err := w.netlink.RouteList(link, netlink.FAMILY_V6)
if err != nil { if err != nil {
return false, fmt.Errorf("%w: %s", errRouteList, err) return false, fmt.Errorf("cannot list routes: %w", err)
} }
if len(routes) > 0 { if len(routes) > 0 {

View File

@@ -19,7 +19,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
err = w.netlink.RouteAdd(route) err = w.netlink.RouteAdd(route)
if err != nil { if err != nil {
return fmt.Errorf("%w: when adding route: %s", err, route) return fmt.Errorf("cannot add route %s: %w", route, err)
} }
return err return err

View File

@@ -53,7 +53,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
Table: firewallMark, Table: firewallMark,
}, },
routeAddErr: errDummy, routeAddErr: errDummy,
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820 Realm: 0}"), //nolint:lll err: errors.New("cannot add route {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820 Realm: 0}: dummy"), //nolint:lll
}, },
} }

View File

@@ -14,13 +14,13 @@ func (w *Wireguard) addRule(rulePriority, firewallMark int) (
rule.Mark = firewallMark rule.Mark = firewallMark
rule.Table = firewallMark rule.Table = firewallMark
if err := w.netlink.RuleAdd(rule); err != nil { if err := w.netlink.RuleAdd(rule); err != nil {
return nil, fmt.Errorf("%w: when adding rule: %s", err, rule) return nil, fmt.Errorf("cannot add rule %s: %w", rule, err)
} }
cleanup = func() error { cleanup = func() error {
err := w.netlink.RuleDel(rule) err := w.netlink.RuleDel(rule)
if err != nil { if err != nil {
return fmt.Errorf("%w: when deleting rule: %s", err, rule) return fmt.Errorf("cannot delete rule %s: %w", rule, err)
} }
return nil return nil
} }

View File

@@ -51,7 +51,7 @@ func Test_Wireguard_addRule(t *testing.T) {
SuppressPrefixlen: -1, SuppressPrefixlen: -1,
}, },
ruleAddErr: errDummy, ruleAddErr: errDummy,
err: errors.New("dummy: when adding rule: ip rule 987: from all to all table 456"), err: errors.New("cannot add rule ip rule 987: from all to all table 456: dummy"),
}, },
"rule delete error": { "rule delete error": {
expectedRule: &netlink.Rule{ expectedRule: &netlink.Rule{
@@ -66,7 +66,7 @@ func Test_Wireguard_addRule(t *testing.T) {
SuppressPrefixlen: -1, SuppressPrefixlen: -1,
}, },
ruleDelErr: errDummy, ruleDelErr: errDummy,
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from all to all table 456"), cleanupErr: errors.New("cannot delete rule ip rule 987: from all to all table 456: dummy"),
}, },
} }