chore(errors): review all errors in codebase
This commit is contained in:
@@ -56,11 +56,6 @@ var (
|
|||||||
created = "an unknown date"
|
created = "an unknown date"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errSetupRouting = errors.New("cannot setup routing")
|
|
||||||
errCreateUser = errors.New("cannot create user")
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
buildInfo := models.BuildInformation{
|
buildInfo := models.BuildInformation{
|
||||||
Version: version,
|
Version: version,
|
||||||
@@ -278,7 +273,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
const defaultUsername = "nonrootuser"
|
const defaultUsername = "nonrootuser"
|
||||||
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
|
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errCreateUser, err)
|
return fmt.Errorf("cannot create user: %w", err)
|
||||||
}
|
}
|
||||||
if nonRootUsername != defaultUsername {
|
if nonRootUsername != defaultUsername {
|
||||||
logger.Info("using existing username " + nonRootUsername + " corresponding to user id " + fmt.Sprint(puid))
|
logger.Info("using existing username " + nonRootUsername + " corresponding to user id " + fmt.Sprint(puid))
|
||||||
@@ -296,7 +291,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
if strings.Contains(err.Error(), "operation not permitted") {
|
if strings.Contains(err.Error(), "operation not permitted") {
|
||||||
logger.Warn("💡 Tip: Are you passing NET_ADMIN capability to gluetun?")
|
logger.Warn("💡 Tip: Are you passing NET_ADMIN capability to gluetun?")
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %s", errSetupRouting, err)
|
return fmt.Errorf("cannot setup routing: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
logger.Info("routing cleanup...")
|
logger.Info("routing cleanup...")
|
||||||
|
|||||||
@@ -18,9 +18,6 @@ type ServersFormatter interface {
|
|||||||
var (
|
var (
|
||||||
ErrFormatNotRecognized = errors.New("format is not recognized")
|
ErrFormatNotRecognized = errors.New("format is not recognized")
|
||||||
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
|
ErrProviderUnspecified = errors.New("VPN provider to format was not specified")
|
||||||
ErrOpenOutputFile = errors.New("cannot open output file")
|
|
||||||
ErrWriteOutput = errors.New("cannot write to output file")
|
|
||||||
ErrCloseOutputFile = errors.New("cannot close output file")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *CLI) FormatServers(args []string) error {
|
func (c *CLI) FormatServers(args []string) error {
|
||||||
@@ -62,7 +59,7 @@ func (c *CLI) FormatServers(args []string) error {
|
|||||||
logger := newNoopLogger()
|
logger := newNoopLogger()
|
||||||
storage, err := storage.New(logger, constants.ServersData)
|
storage, err := storage.New(logger, constants.ServersData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrNewStorage, err)
|
return fmt.Errorf("cannot create servers storage: %w", err)
|
||||||
}
|
}
|
||||||
currentServers := storage.GetServers()
|
currentServers := storage.GetServers()
|
||||||
|
|
||||||
@@ -115,18 +112,18 @@ func (c *CLI) FormatServers(args []string) error {
|
|||||||
output = filepath.Clean(output)
|
output = filepath.Clean(output)
|
||||||
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
|
file, err := os.OpenFile(output, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrOpenOutputFile, err)
|
return fmt.Errorf("cannot open output file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = fmt.Fprint(file, formatted)
|
_, err = fmt.Fprint(file, formatted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = file.Close()
|
_ = file.Close()
|
||||||
return fmt.Errorf("%w: %s", ErrWriteOutput, err)
|
return fmt.Errorf("cannot write to output file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrCloseOutputFile, err)
|
return fmt.Errorf("cannot close output file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -23,9 +23,6 @@ var (
|
|||||||
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
|
ErrModeUnspecified = errors.New("at least one of -enduser or -maintainer must be specified")
|
||||||
ErrDNSAddress = errors.New("DNS address is not valid")
|
ErrDNSAddress = errors.New("DNS address is not valid")
|
||||||
ErrNoProviderSpecified = errors.New("no provider was specified")
|
ErrNoProviderSpecified = errors.New("no provider was specified")
|
||||||
ErrNewStorage = errors.New("cannot create storage")
|
|
||||||
ErrUpdateServerInformation = errors.New("cannot update server information")
|
|
||||||
ErrWriteToFile = errors.New("cannot write updated information to file")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Updater interface {
|
type Updater interface {
|
||||||
@@ -90,25 +87,25 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
|||||||
|
|
||||||
storage, err := storage.New(logger, constants.ServersData)
|
storage, err := storage.New(logger, constants.ServersData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrNewStorage, err)
|
return fmt.Errorf("cannot create servers storage: %w", err)
|
||||||
}
|
}
|
||||||
currentServers := storage.GetServers()
|
currentServers := storage.GetServers()
|
||||||
|
|
||||||
updater := updater.New(options, httpClient, currentServers, logger)
|
updater := updater.New(options, httpClient, currentServers, logger)
|
||||||
allServers, err := updater.UpdateServers(ctx)
|
allServers, err := updater.UpdateServers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrUpdateServerInformation, err)
|
return fmt.Errorf("cannot update server information: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if endUserMode {
|
if endUserMode {
|
||||||
if err := storage.FlushToFile(allServers); err != nil {
|
if err := storage.FlushToFile(allServers); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrWriteToFile, err)
|
return fmt.Errorf("cannot write updated information to file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if maintainerMode {
|
if maintainerMode {
|
||||||
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
|
if err := writeToEmbeddedJSON(c.repoServersPath, allServers); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrWriteToFile, err)
|
return fmt.Errorf("cannot write updated information to file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,19 +4,15 @@ import "errors"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrCityNotValid = errors.New("the city specified is not valid")
|
ErrCityNotValid = errors.New("the city specified is not valid")
|
||||||
ErrControlServerAddress = errors.New("listening address it not valid")
|
|
||||||
ErrControlServerPort = errors.New("listening port it not valid")
|
|
||||||
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
||||||
ErrCountryNotValid = errors.New("the country specified is not valid")
|
ErrCountryNotValid = errors.New("the country specified is not valid")
|
||||||
|
ErrFilepathMissing = errors.New("filepath is missing")
|
||||||
ErrFirewallZeroPort = errors.New("cannot have a zero port to block")
|
ErrFirewallZeroPort = errors.New("cannot have a zero port to block")
|
||||||
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
||||||
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
||||||
|
ErrMissingValue = errors.New("missing value")
|
||||||
ErrNameNotValid = errors.New("the server name specified is not valid")
|
ErrNameNotValid = errors.New("the server name specified is not valid")
|
||||||
ErrOpenVPNClientCertMissing = errors.New("client certificate is missing")
|
|
||||||
ErrOpenVPNClientCertNotValid = errors.New("client certificate is not valid")
|
|
||||||
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
|
ErrOpenVPNClientKeyMissing = errors.New("client key is missing")
|
||||||
ErrOpenVPNClientKeyNotValid = errors.New("client key is not valid")
|
|
||||||
ErrOpenVPNConfigFile = errors.New("custom configuration file error")
|
|
||||||
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
|
ErrOpenVPNCustomPortNotAllowed = errors.New("custom endpoint port is not allowed")
|
||||||
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
|
ErrOpenVPNEncryptionPresetNotValid = errors.New("PIA encryption preset is not valid")
|
||||||
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
|
ErrOpenVPNInterfaceNotValid = errors.New("interface name is not valid")
|
||||||
@@ -27,8 +23,6 @@ var (
|
|||||||
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
|
ErrOpenVPNVerbosityIsOutOfBounds = errors.New("verbosity value is out of bounds")
|
||||||
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
|
ErrOpenVPNVersionIsNotValid = errors.New("version is not valid")
|
||||||
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
|
ErrPortForwardingEnabled = errors.New("port forwarding cannot be enabled")
|
||||||
ErrPortForwardingFilepathNotValid = errors.New("port forwarding filepath given is not valid")
|
|
||||||
ErrPublicIPFilepathNotValid = errors.New("public IP address file path is not valid")
|
|
||||||
ErrPublicIPPeriodTooShort = errors.New("public IP address check period is too short")
|
ErrPublicIPPeriodTooShort = errors.New("public IP address check period is too short")
|
||||||
ErrRegionNotValid = errors.New("the region specified is not valid")
|
ErrRegionNotValid = errors.New("the region specified is not valid")
|
||||||
ErrServerAddressNotValid = errors.New("server listening address is not valid")
|
ErrServerAddressNotValid = errors.New("server listening address is not valid")
|
||||||
@@ -44,9 +38,7 @@ var (
|
|||||||
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
|
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
|
||||||
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
|
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
|
||||||
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
|
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
|
||||||
ErrWireguardPreSharedKeyNotValid = errors.New("pre-shared key is not valid")
|
|
||||||
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")
|
||||||
ErrWireguardPrivateKeyNotValid = errors.New("private key is not valid")
|
|
||||||
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
ErrWireguardPublicKeyNotSet = errors.New("public key is not set")
|
||||||
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
ErrWireguardPublicKeyNotValid = errors.New("public key is not valid")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -27,13 +27,12 @@ func (h Health) Validate() (err error) {
|
|||||||
_, err = address.Validate(h.ServerAddress,
|
_, err = address.Validate(h.ServerAddress,
|
||||||
address.OptionListening(uid))
|
address.OptionListening(uid))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s",
|
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||||
ErrServerAddressNotValid, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.VPN.validate()
|
err = h.VPN.validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("health VPN settings validation failed: %w", err)
|
return fmt.Errorf("health VPN settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -41,8 +41,7 @@ func (h HTTPProxy) validate() (err error) {
|
|||||||
uid := os.Getuid()
|
uid := os.Getuid()
|
||||||
_, err = address.Validate(h.ListeningAddress, address.OptionListening(uid))
|
_, err = address.Validate(h.ListeningAddress, address.OptionListening(uid))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s",
|
return fmt.Errorf("%w: %s", ErrServerAddressNotValid, h.ListeningAddress)
|
||||||
ErrServerAddressNotValid, h.ListeningAddress)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -93,17 +93,17 @@ func (o OpenVPN) validate(vpnProvider string) (err error) {
|
|||||||
|
|
||||||
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
|
err = validateOpenVPNConfigFilepath(isCustom, *o.ConfFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("custom configuration file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateOpenVPNClientCertificate(vpnProvider, *o.ClientCrt)
|
err = validateOpenVPNClientCertificate(vpnProvider, *o.ClientCrt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("client certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateOpenVPNClientKey(vpnProvider, *o.ClientKey)
|
err = validateOpenVPNClientKey(vpnProvider, *o.ClientKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("client key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxMSSFix = 10000
|
const maxMSSFix = 10000
|
||||||
@@ -132,12 +132,12 @@ func validateOpenVPNConfigFilepath(isCustom bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if confFile == "" {
|
if confFile == "" {
|
||||||
return fmt.Errorf("%w: no file path specified", ErrOpenVPNConfigFile)
|
return ErrFilepathMissing
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helpers.FileExists(confFile)
|
err = helpers.FileExists(confFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrOpenVPNConfigFile, err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -150,7 +150,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
|
|||||||
constants.Cyberghost,
|
constants.Cyberghost,
|
||||||
constants.VPNUnlimited:
|
constants.VPNUnlimited:
|
||||||
if clientCert == "" {
|
if clientCert == "" {
|
||||||
return ErrOpenVPNClientCertMissing
|
return ErrMissingValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,7 +160,7 @@ func validateOpenVPNClientCertificate(vpnProvider,
|
|||||||
|
|
||||||
_, err = parse.ExtractCert([]byte(clientCert))
|
_, err = parse.ExtractCert([]byte(clientCert))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrOpenVPNClientCertNotValid, err)
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -172,7 +172,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
|
|||||||
constants.VPNUnlimited,
|
constants.VPNUnlimited,
|
||||||
constants.Wevpn:
|
constants.Wevpn:
|
||||||
if clientKey == "" {
|
if clientKey == "" {
|
||||||
return ErrOpenVPNClientKeyMissing
|
return ErrMissingValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +182,7 @@ func validateOpenVPNClientKey(vpnProvider, clientKey string) (err error) {
|
|||||||
|
|
||||||
_, err = parse.ExtractPrivateKey([]byte(clientKey))
|
_, err = parse.ExtractPrivateKey([]byte(clientKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrOpenVPNClientKeyNotValid, err)
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,14 +33,16 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) {
|
|||||||
if confFile := *o.ConfFile; confFile != "" {
|
if confFile := *o.ConfFile; confFile != "" {
|
||||||
err := helpers.FileExists(confFile)
|
err := helpers.FileExists(confFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrOpenVPNConfigFile, err)
|
return fmt.Errorf("configuration file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate TCP
|
// Validate TCP
|
||||||
if *o.TCP && helpers.IsOneOf(vpnProvider,
|
if *o.TCP && helpers.IsOneOf(vpnProvider,
|
||||||
|
constants.Ipvanish,
|
||||||
constants.Perfectprivacy,
|
constants.Perfectprivacy,
|
||||||
constants.Privado,
|
constants.Privado,
|
||||||
|
constants.VPNUnlimited,
|
||||||
constants.Vyprvpn,
|
constants.Vyprvpn,
|
||||||
) {
|
) {
|
||||||
return fmt.Errorf("%w: for VPN service provider %s",
|
return fmt.Errorf("%w: for VPN service provider %s",
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (p PortForwarding) validate(vpnProvider string) (err error) {
|
|||||||
if *p.Filepath != "" { // optional
|
if *p.Filepath != "" { // optional
|
||||||
_, err := filepath.Abs(*p.Filepath)
|
_, err := filepath.Abs(*p.Filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrPortForwardingFilepathNotValid, err)
|
return fmt.Errorf("filepath is not valid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,12 +43,12 @@ func (p *Provider) validate(vpnType string, allServers models.AllServers) (err e
|
|||||||
|
|
||||||
err = p.ServerSelection.validate(*p.Name, allServers)
|
err = p.ServerSelection.validate(*p.Name, allServers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server selection settings validation failed: %w", err)
|
return fmt.Errorf("server selection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.PortForwarding.validate(*p.Name)
|
err = p.PortForwarding.validate(*p.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("port forwarding settings validation failed: %w", err)
|
return fmt.Errorf("port forwarding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func (p PublicIP) validate() (err error) {
|
|||||||
if *p.IPFilepath != "" { // optional
|
if *p.IPFilepath != "" { // optional
|
||||||
_, err := filepath.Abs(*p.IPFilepath)
|
_, err := filepath.Abs(*p.IPFilepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrPublicIPFilepathNotValid, err)
|
return fmt.Errorf("filepath is not valid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,12 +23,12 @@ type ControlServer struct {
|
|||||||
func (c ControlServer) validate() (err error) {
|
func (c ControlServer) validate() (err error) {
|
||||||
_, portStr, err := net.SplitHostPort(*c.Address)
|
_, portStr, err := net.SplitHostPort(*c.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrControlServerAddress, err)
|
return fmt.Errorf("listening address is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrControlServerPort, err)
|
return fmt.Errorf("listening port it not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
uid := os.Getuid()
|
uid := os.Getuid()
|
||||||
|
|||||||
@@ -116,18 +116,18 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
|
|||||||
if *ss.MultiHopOnly &&
|
if *ss.MultiHopOnly &&
|
||||||
vpnServiceProvider != constants.Surfshark {
|
vpnServiceProvider != constants.Surfshark {
|
||||||
return fmt.Errorf("%w: for VPN service provider %s",
|
return fmt.Errorf("%w: for VPN service provider %s",
|
||||||
ErrStreamOnlyNotSupported, vpnServiceProvider)
|
ErrMultiHopOnlyNotSupported, vpnServiceProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ss.VPN == constants.OpenVPN {
|
if ss.VPN == constants.OpenVPN {
|
||||||
err = ss.OpenVPN.validate(vpnServiceProvider)
|
err = ss.OpenVPN.validate(vpnServiceProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("OpenVPN server selection settings validation failed: %w", err)
|
return fmt.Errorf("OpenVPN server selection settings: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = ss.Wireguard.validate(vpnServiceProvider)
|
err = ss.Wireguard.validate(vpnServiceProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Wireguard server selection settings validation failed: %w", err)
|
return fmt.Errorf("Wireguard server selection settings: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (s *Settings) Validate(allServers models.AllServers) (err error) {
|
|||||||
for name, validation := range nameToValidation {
|
for name, validation := range nameToValidation {
|
||||||
err = validation()
|
err = validation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed validating %s settings: %w", name, err)
|
return fmt.Errorf("%s settings: %w", name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,18 +31,18 @@ func (v *VPN) validate(allServers models.AllServers) (err error) {
|
|||||||
|
|
||||||
err = v.Provider.validate(v.Type, allServers)
|
err = v.Provider.validate(v.Type, allServers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("provider settings validation failed: %w", err)
|
return fmt.Errorf("provider settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.Type == constants.OpenVPN {
|
if v.Type == constants.OpenVPN {
|
||||||
err := v.OpenVPN.validate(*v.Provider.Name)
|
err := v.OpenVPN.validate(*v.Provider.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("OpenVPN settings validation failed: %w", err)
|
return fmt.Errorf("OpenVPN settings: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err := v.Wireguard.validate(*v.Provider.Name)
|
err := v.Wireguard.validate(*v.Provider.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Wireguard settings validation failed: %w", err)
|
return fmt.Errorf("Wireguard settings: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,14 +50,14 @@ func (w Wireguard) validate(vpnProvider string) (err error) {
|
|||||||
}
|
}
|
||||||
_, err = wgtypes.ParseKey(*w.PrivateKey)
|
_, err = wgtypes.ParseKey(*w.PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrWireguardPrivateKeyNotValid, err)
|
return fmt.Errorf("private key is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate PreSharedKey
|
// Validate PreSharedKey
|
||||||
if *w.PreSharedKey != "" { // Note: this is optional
|
if *w.PreSharedKey != "" { // Note: this is optional
|
||||||
_, err = wgtypes.ParseKey(*w.PreSharedKey)
|
_, err = wgtypes.ParseKey(*w.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrWireguardPreSharedKeyNotValid, err)
|
return fmt.Errorf("pre-shared key is not valid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
2
internal/configuration/sources/env/dns.go
vendored
2
internal/configuration/sources/env/dns.go
vendored
@@ -20,7 +20,7 @@ func (r *Reader) readDNS() (dns settings.DNS, err error) {
|
|||||||
|
|
||||||
dns.DoT, err = r.readDoT()
|
dns.DoT, err = r.readDoT()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return dns, fmt.Errorf("cannot read DoT settings: %w", err)
|
return dns, fmt.Errorf("DoT settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dns, nil
|
return dns, nil
|
||||||
|
|||||||
10
internal/configuration/sources/env/firewall.go
vendored
10
internal/configuration/sources/env/firewall.go
vendored
@@ -55,8 +55,7 @@ func stringsToPorts(ss []string) (ports []uint16, err error) {
|
|||||||
for i, s := range ss {
|
for i, s := range ss {
|
||||||
port, err := strconv.Atoi(s)
|
port, err := strconv.Atoi(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s: %s",
|
return nil, fmt.Errorf("%w: %s: %s", ErrPortParsing, s, err)
|
||||||
ErrPortParsing, s, err)
|
|
||||||
} else if port < 1 || port > 65535 {
|
} else if port < 1 || port > 65535 {
|
||||||
return nil, fmt.Errorf("%w: must be between 1 and 65535: %d",
|
return nil, fmt.Errorf("%w: must be between 1 and 65535: %d",
|
||||||
ErrPortValue, port)
|
ErrPortValue, port)
|
||||||
@@ -66,10 +65,6 @@ func stringsToPorts(ss []string) (ports []uint16, err error) {
|
|||||||
return ports, nil
|
return ports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrIPNetParsing = errors.New("cannot parse IP network")
|
|
||||||
)
|
|
||||||
|
|
||||||
func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) {
|
func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) {
|
||||||
if len(ss) == 0 {
|
if len(ss) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -78,8 +73,7 @@ func stringsToIPNets(ss []string) (ipNets []net.IPNet, err error) {
|
|||||||
for i, s := range ss {
|
for i, s := range ss {
|
||||||
ip, ipNet, err := net.ParseCIDR(s)
|
ip, ipNet, err := net.ParseCIDR(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s: %s",
|
return nil, fmt.Errorf("cannot parse IP network %q: %w", s, err)
|
||||||
ErrIPNetParsing, s, err)
|
|
||||||
}
|
}
|
||||||
ipNet.IP = ip
|
ipNet.IP = ip
|
||||||
ipNets[i] = *ipNet
|
ipNets[i] = *ipNet
|
||||||
|
|||||||
4
internal/configuration/sources/env/health.go
vendored
4
internal/configuration/sources/env/health.go
vendored
@@ -38,9 +38,7 @@ func (r *Reader) readDurationWithRetro(envKey, retroEnvKey string) (d *time.Dura
|
|||||||
d = new(time.Duration)
|
d = new(time.Duration)
|
||||||
*d, err = time.ParseDuration(s)
|
*d, err = time.ParseDuration(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf("environment variable %s: %w", envKey, err)
|
||||||
"environment variable %s: %w",
|
|
||||||
envKey, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package env
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -115,13 +114,11 @@ func lowerAndSplit(csv string) (values []string) {
|
|||||||
return strings.Split(csv, ",")
|
return strings.Split(csv, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrDecodeBase64 = errors.New("cannot decode base64 string")
|
|
||||||
|
|
||||||
func decodeBase64(b64String string) (decoded string, err error) {
|
func decodeBase64(b64String string) (decoded string, err error) {
|
||||||
b, err := base64.StdEncoding.DecodeString(b64String)
|
b, err := base64.StdEncoding.DecodeString(b64String)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("%w: %s: %s",
|
return "", fmt.Errorf("cannot decode base64 string %q: %w",
|
||||||
ErrDecodeBase64, b64String, err)
|
b64String, err)
|
||||||
}
|
}
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|||||||
2
internal/configuration/sources/env/log.go
vendored
2
internal/configuration/sources/env/log.go
vendored
@@ -48,7 +48,7 @@ func parseLogLevel(s string) (level logging.Level, err error) {
|
|||||||
return logging.LevelError, nil
|
return logging.LevelError, nil
|
||||||
default:
|
default:
|
||||||
return level, fmt.Errorf(
|
return level, fmt.Errorf(
|
||||||
"%w: %s: can be one of: debug, info, warning or error",
|
"%w: %q is not valid and can be one of debug, info, warning or error",
|
||||||
ErrLogLevelUnknown, s)
|
ErrLogLevelUnknown, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ func (r *Reader) readProvider(vpnType string) (provider settings.Provider, err e
|
|||||||
|
|
||||||
provider.ServerSelection, err = r.readServerSelection(providerName, vpnType)
|
provider.ServerSelection, err = r.readServerSelection(providerName, vpnType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return provider, fmt.Errorf("cannot read server selection settings: %w", err)
|
return provider, fmt.Errorf("server selection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
provider.PortForwarding, err = r.readPortForward()
|
provider.PortForwarding, err = r.readPortForward()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return provider, fmt.Errorf("cannot read port forwarding settings: %w", err)
|
return provider, fmt.Errorf("port forwarding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return provider, nil
|
return provider, nil
|
||||||
|
|||||||
6
internal/configuration/sources/env/vpn.go
vendored
6
internal/configuration/sources/env/vpn.go
vendored
@@ -13,17 +13,17 @@ func (r *Reader) readVPN() (vpn settings.VPN, err error) {
|
|||||||
|
|
||||||
vpn.Provider, err = r.readProvider(vpn.Type)
|
vpn.Provider, err = r.readProvider(vpn.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vpn, fmt.Errorf("cannot read provider settings: %w", err)
|
return vpn, fmt.Errorf("VPN provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
vpn.OpenVPN, err = r.readOpenVPN()
|
vpn.OpenVPN, err = r.readOpenVPN()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vpn, fmt.Errorf("cannot read OpenVPN settings: %w", err)
|
return vpn, fmt.Errorf("OpenVPN: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
vpn.Wireguard, err = r.readWireguard()
|
vpn.Wireguard, err = r.readWireguard()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vpn, fmt.Errorf("cannot read Wireguard settings: %w", err)
|
return vpn, fmt.Errorf("wireguard: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return vpn, nil
|
return vpn, nil
|
||||||
|
|||||||
@@ -16,12 +16,12 @@ const (
|
|||||||
func (r *Reader) readOpenVPN() (settings settings.OpenVPN, err error) {
|
func (r *Reader) readOpenVPN() (settings settings.OpenVPN, err error) {
|
||||||
settings.ClientKey, err = ReadFromFile(OpenVPNClientKeyPath)
|
settings.ClientKey, err = ReadFromFile(OpenVPNClientKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return settings, fmt.Errorf("cannot read client key: %w", err)
|
return settings, fmt.Errorf("client key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings.ClientCrt, err = ReadFromFile(OpenVPNClientCertificatePath)
|
settings.ClientCrt, err = ReadFromFile(OpenVPNClientCertificatePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return settings, fmt.Errorf("cannot read client certificate: %w", err)
|
return settings, fmt.Errorf("client certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return settings, nil
|
return settings, nil
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
func (r *Reader) readVPN() (vpn settings.VPN, err error) {
|
func (r *Reader) readVPN() (vpn settings.VPN, err error) {
|
||||||
vpn.OpenVPN, err = r.readOpenVPN()
|
vpn.OpenVPN, err = r.readOpenVPN()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vpn, fmt.Errorf("cannot read OpenVPN settings: %w", err)
|
return vpn, fmt.Errorf("OpenVPN: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return vpn, nil
|
return vpn, nil
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func (r *Reader) Read() (settings settings.Settings, err error) {
|
|||||||
for _, source := range r.sources {
|
for _, source := range r.sources {
|
||||||
settingsFromSource, err := source.Read()
|
settingsFromSource, err := source.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return settings, fmt.Errorf("cannot read from source %T: %w", source, err)
|
return settings, fmt.Errorf("reading from source %T: %w", source, err)
|
||||||
}
|
}
|
||||||
settings.MergeWith(settingsFromSource)
|
settings.MergeWith(settingsFromSource)
|
||||||
}
|
}
|
||||||
@@ -42,7 +42,7 @@ func (r *Reader) ReadHealth() (settings settings.Health, err error) {
|
|||||||
for _, source := range r.sources {
|
for _, source := range r.sources {
|
||||||
settingsFromSource, err := source.ReadHealth()
|
settingsFromSource, err := source.ReadHealth()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return settings, fmt.Errorf("cannot read from source %T: %w", source, err)
|
return settings, fmt.Errorf("reading from source %T: %w", source, err)
|
||||||
}
|
}
|
||||||
settings.MergeWith(settingsFromSource)
|
settings.MergeWith(settingsFromSource)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,16 +2,9 @@ package firewall
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrEnable = errors.New("failed enabling firewall")
|
|
||||||
ErrDisable = errors.New("failed disabling firewall")
|
|
||||||
ErrUserPostRules = errors.New("cannot run user post firewall rules")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Enabler interface {
|
type Enabler interface {
|
||||||
SetEnabled(ctx context.Context, enabled bool) (err error)
|
SetEnabled(ctx context.Context, enabled bool) (err error)
|
||||||
}
|
}
|
||||||
@@ -32,7 +25,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
|||||||
if !enabled {
|
if !enabled {
|
||||||
c.logger.Info("disabling...")
|
c.logger.Info("disabling...")
|
||||||
if err = c.disable(ctx); err != nil {
|
if err = c.disable(ctx); err != nil {
|
||||||
return err
|
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||||
}
|
}
|
||||||
c.enabled = false
|
c.enabled = false
|
||||||
c.logger.Info("disabled successfully")
|
c.logger.Info("disabled successfully")
|
||||||
@@ -42,7 +35,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
|||||||
c.logger.Info("enabling...")
|
c.logger.Info("enabling...")
|
||||||
|
|
||||||
if err := c.enable(ctx); err != nil {
|
if err := c.enable(ctx); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrEnable, err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
c.enabled = true
|
c.enabled = true
|
||||||
c.logger.Info("enabled successfully")
|
c.logger.Info("enabled successfully")
|
||||||
@@ -52,13 +45,13 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
|||||||
|
|
||||||
func (c *Config) disable(ctx context.Context) (err error) {
|
func (c *Config) disable(ctx context.Context) (err error) {
|
||||||
if err = c.clearAllRules(ctx); err != nil {
|
if err = c.clearAllRules(ctx); err != nil {
|
||||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
return fmt.Errorf("cannot clear all rules: %w", err)
|
||||||
}
|
}
|
||||||
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
return fmt.Errorf("cannot set ipv4 policies: %w", err)
|
||||||
}
|
}
|
||||||
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
return fmt.Errorf("cannot set ipv6 policies: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -76,12 +69,12 @@ func (c *Config) fallbackToDisabled(ctx context.Context) {
|
|||||||
func (c *Config) enable(ctx context.Context) (err error) {
|
func (c *Config) enable(ctx context.Context) (err error) {
|
||||||
touched := false
|
touched := false
|
||||||
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
|
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
touched = true
|
touched = true
|
||||||
|
|
||||||
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
|
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const remove = false
|
const remove = false
|
||||||
@@ -94,33 +87,33 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
// Loopback traffic
|
// Loopback traffic
|
||||||
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
if c.vpnConnection.IP != nil {
|
if c.vpnConnection.IP != nil {
|
||||||
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range c.localNetworks {
|
for _, network := range c.localNetworks {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {
|
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, subnet := range c.outboundSubnets {
|
for _, subnet := range c.outboundSubnets {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,18 +121,18 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
// to reach Gluetun.
|
// to reach Gluetun.
|
||||||
for _, network := range c.localNetworks {
|
for _, network := range c.localNetworks {
|
||||||
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil {
|
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for port, intf := range c.allowedInputPorts {
|
for port, intf := range c.allowedInputPorts {
|
||||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrUserPostRules, err)
|
return fmt.Errorf("cannot run user defined post firewall rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrIP6Tables = errors.New("failed ip6tables command")
|
|
||||||
ErrIP6NotSupported = errors.New("ip6tables not supported")
|
ErrIP6NotSupported = errors.New("ip6tables not supported")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,18 +43,18 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
|
|||||||
flags := strings.Fields(instruction)
|
flags := strings.Fields(instruction)
|
||||||
cmd := exec.CommandContext(ctx, "ip6tables", flags...)
|
cmd := exec.CommandContext(ctx, "ip6tables", flags...)
|
||||||
if output, err := c.runner.Run(cmd); err != nil {
|
if output, err := c.runner.Run(cmd); err != nil {
|
||||||
return fmt.Errorf("%w: \"ip6tables %s\": %s: %s", ErrIP6Tables, instruction, output, err)
|
return fmt.Errorf("command failed: \"ip6tables %s\": %s: %w", instruction, output, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errPolicyNotValid = errors.New("policy is not valid")
|
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||||
|
|
||||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||||
switch policy {
|
switch policy {
|
||||||
case "ACCEPT", "DROP":
|
case "ACCEPT", "DROP":
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", errPolicyNotValid, policy)
|
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||||
}
|
}
|
||||||
return c.runIP6tablesInstructions(ctx, []string{
|
return c.runIP6tablesInstructions(ctx, []string{
|
||||||
"--policy INPUT " + policy,
|
"--policy INPUT " + policy,
|
||||||
|
|||||||
@@ -16,10 +16,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||||
ErrIPTables = errors.New("failed iptables command")
|
|
||||||
ErrPolicyUnknown = errors.New("unknown policy")
|
ErrPolicyUnknown = errors.New("unknown policy")
|
||||||
ErrClearRules = errors.New("cannot clear all rules")
|
|
||||||
ErrSetIPtablesPolicies = errors.New("cannot set iptables policies")
|
|
||||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -79,33 +76,30 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
|||||||
flags := strings.Fields(instruction)
|
flags := strings.Fields(instruction)
|
||||||
cmd := exec.CommandContext(ctx, "iptables", flags...)
|
cmd := exec.CommandContext(ctx, "iptables", flags...)
|
||||||
if output, err := c.runner.Run(cmd); err != nil {
|
if output, err := c.runner.Run(cmd); err != nil {
|
||||||
return fmt.Errorf("%w \"iptables %s\": %s: %s", ErrIPTables, instruction, output, err)
|
return fmt.Errorf("command failed: \"iptables %s\": %s: %w", instruction, output, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) clearAllRules(ctx context.Context) error {
|
func (c *Config) clearAllRules(ctx context.Context) error {
|
||||||
if err := c.runMixedIptablesInstructions(ctx, []string{
|
return c.runMixedIptablesInstructions(ctx, []string{
|
||||||
"--flush", // flush all chains
|
"--flush", // flush all chains
|
||||||
"--delete-chain", // delete all chains
|
"--delete-chain", // delete all chains
|
||||||
}); err != nil {
|
})
|
||||||
return fmt.Errorf("%w: %s", ErrClearRules, err.Error())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||||
switch policy {
|
switch policy {
|
||||||
case "ACCEPT", "DROP":
|
case "ACCEPT", "DROP":
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s: %s", ErrSetIPtablesPolicies, ErrPolicyUnknown, policy)
|
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||||
}
|
}
|
||||||
if err := c.runIptablesInstructions(ctx, []string{
|
if err := c.runIptablesInstructions(ctx, []string{
|
||||||
"--policy INPUT " + policy,
|
"--policy INPUT " + policy,
|
||||||
"--policy OUTPUT " + policy,
|
"--policy OUTPUT " + policy,
|
||||||
"--policy FORWARD " + policy,
|
"--policy FORWARD " + policy,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrSetIPtablesPolicies, err)
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("setting allowed subnets through firewall...")
|
c.logger.Info("setting allowed subnets...")
|
||||||
|
|
||||||
subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets)
|
subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets)
|
||||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||||
@@ -32,7 +32,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
|
|||||||
|
|
||||||
c.removeOutboundSubnets(ctx, subnetsToRemove)
|
c.removeOutboundSubnets(ctx, subnetsToRemove)
|
||||||
if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil {
|
if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil {
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
return fmt.Errorf("cannot set allowed outbound subnets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -42,7 +42,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet)
|
|||||||
const remove = true
|
const remove = true
|
||||||
for _, subNet := range subnets {
|
for _, subNet := range subnets {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil {
|
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil {
|
||||||
c.logger.Error("cannot remove outdated outbound subnet through firewall: " + err.Error())
|
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet)
|
c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet)
|
||||||
@@ -53,7 +53,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) er
|
|||||||
const remove = false
|
const remove = false
|
||||||
for _, subnet := range subnets {
|
for _, subnet := range subnets {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
c.outboundSubnets = append(c.outboundSubnets, subnet)
|
c.outboundSubnets = append(c.outboundSubnets, subnet)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,13 +33,13 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
|
|||||||
}
|
}
|
||||||
const remove = true
|
const remove = true
|
||||||
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
|
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
|
||||||
return fmt.Errorf("cannot remove old allowed port %d through interface %s: %w", port, existingIntf, err)
|
return fmt.Errorf("cannot remove old allowed port %d: %w", port, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const remove = false
|
const remove = false
|
||||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||||
return fmt.Errorf("cannot set allowed port %d through interface %s: %w", port, intf, err)
|
return fmt.Errorf("cannot allow input to port %d: %w", port, err)
|
||||||
}
|
}
|
||||||
c.allowedInputPorts[port] = intf
|
c.allowedInputPorts[port] = intf
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " through firewall...")
|
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " ...")
|
||||||
|
|
||||||
intf, ok := c.allowedInputPorts[port]
|
intf, ok := c.allowedInputPorts[port]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -69,7 +69,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
|||||||
|
|
||||||
const remove = true
|
const remove = true
|
||||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||||
return fmt.Errorf("cannot remove allowed port %d through interface %s: %w", port, intf, err)
|
return fmt.Errorf("cannot remove allowed port %d: %w", port, err)
|
||||||
}
|
}
|
||||||
delete(c.allowedInputPorts, port)
|
delete(c.allowedInputPorts, port)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("setting VPN connection through firewall...")
|
c.logger.Info("allowing VPN connection...")
|
||||||
|
|
||||||
if c.vpnConnection.Equal(connection) {
|
if c.vpnConnection.Equal(connection) {
|
||||||
return nil
|
return nil
|
||||||
@@ -32,14 +32,14 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
|||||||
remove := true
|
remove := true
|
||||||
if c.vpnConnection.IP != nil {
|
if c.vpnConnection.IP != nil {
|
||||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
||||||
c.logger.Error("cannot remove outdated VPN connection through firewall: " + err.Error())
|
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.vpnConnection = models.Connection{}
|
c.vpnConnection = models.Connection{}
|
||||||
|
|
||||||
if c.vpnIntf != "" {
|
if c.vpnIntf != "" {
|
||||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||||
c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error())
|
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.vpnIntf = ""
|
c.vpnIntf = ""
|
||||||
@@ -47,7 +47,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
|||||||
remove = false
|
remove = false
|
||||||
|
|
||||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
|
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
|
||||||
return fmt.Errorf("cannot set VPN connection through firewall: %w", err)
|
return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err)
|
||||||
}
|
}
|
||||||
c.vpnConnection = connection
|
c.vpnConnection = connection
|
||||||
|
|
||||||
|
|||||||
@@ -45,5 +45,6 @@ func (c *Client) Check(ctx context.Context, url string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %s: %s", ErrHTTPStatusNotOK, response.Status, string(b))
|
return fmt.Errorf("%w: %d %s: %s", ErrHTTPStatusNotOK,
|
||||||
|
response.StatusCode, response.Status, string(b))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ func (e *Extractor) Data(filepath string) (lines []string,
|
|||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
lines, err = readCustomConfigLines(filepath)
|
lines, err = readCustomConfigLines(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, connection, fmt.Errorf("%w: %s", ErrRead, err)
|
return nil, connection, fmt.Errorf("cannot read configuration file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connection, err = extractDataFromLines(lines)
|
connection, err = extractDataFromLines(lines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err)
|
return nil, connection, fmt.Errorf("cannot extract connection from file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return lines, connection, nil
|
return lines, connection, nil
|
||||||
|
|||||||
@@ -48,25 +48,20 @@ func extractDataFromLines(lines []string) (
|
|||||||
return connection, nil
|
return connection, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errExtractProto = errors.New("failed extracting protocol from proto line")
|
|
||||||
errExtractRemote = errors.New("failed extracting from remote line")
|
|
||||||
)
|
|
||||||
|
|
||||||
func extractDataFromLine(line string) (
|
func extractDataFromLine(line string) (
|
||||||
ip net.IP, port uint16, protocol string, err error) {
|
ip net.IP, port uint16, protocol string, err error) {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(line, "proto "):
|
case strings.HasPrefix(line, "proto "):
|
||||||
protocol, err = extractProto(line)
|
protocol, err = extractProto(line)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, "", fmt.Errorf("%w: %s", errExtractProto, err)
|
return nil, 0, "", fmt.Errorf("failed extracting protocol from proto line: %w", err)
|
||||||
}
|
}
|
||||||
return nil, 0, protocol, nil
|
return nil, 0, protocol, nil
|
||||||
|
|
||||||
case strings.HasPrefix(line, "remote "):
|
case strings.HasPrefix(line, "remote "):
|
||||||
ip, port, protocol, err = extractRemote(line)
|
ip, port, protocol, err = extractRemote(line)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, "", fmt.Errorf("%w: %s", errExtractRemote, err)
|
return nil, 0, "", fmt.Errorf("failed extracting from remote line: %w", err)
|
||||||
}
|
}
|
||||||
return ip, port, protocol, nil
|
return ip, port, protocol, nil
|
||||||
}
|
}
|
||||||
@@ -122,7 +117,7 @@ func extractRemote(line string) (ip net.IP, port uint16,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line)
|
return nil, 0, "", fmt.Errorf("%w: %s", errPortNotValid, line)
|
||||||
} else if portInt < 1 || portInt > 65535 {
|
} else if portInt < 1 || portInt > 65535 {
|
||||||
return nil, 0, "", fmt.Errorf("%w: not between 1 and 65535: %d", errPortNotValid, portInt)
|
return nil, 0, "", fmt.Errorf("%w: %d must be between 1 and 65535", errPortNotValid, portInt)
|
||||||
}
|
}
|
||||||
port = uint16(portInt)
|
port = uint16(portInt)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func Test_extractDataFromLine(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"extract proto error": {
|
"extract proto error": {
|
||||||
line: "proto bad",
|
line: "proto bad",
|
||||||
isErr: errExtractProto,
|
isErr: errProtocolNotSupported,
|
||||||
},
|
},
|
||||||
"extract proto success": {
|
"extract proto success": {
|
||||||
line: "proto tcp",
|
line: "proto tcp",
|
||||||
@@ -106,7 +106,7 @@ func Test_extractDataFromLine(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"extract remote error": {
|
"extract remote error": {
|
||||||
line: "remote bad",
|
line: "remote bad",
|
||||||
isErr: errExtractRemote,
|
isErr: errHostNotIP,
|
||||||
},
|
},
|
||||||
"extract remote success": {
|
"extract remote success": {
|
||||||
line: "remote 1.2.3.4 1194 udp",
|
line: "remote 1.2.3.4 1194 udp",
|
||||||
@@ -213,15 +213,15 @@ func Test_extractRemote(t *testing.T) {
|
|||||||
},
|
},
|
||||||
"port is zero": {
|
"port is zero": {
|
||||||
line: "remote 1.2.3.4 0",
|
line: "remote 1.2.3.4 0",
|
||||||
err: errors.New("port is not valid: not between 1 and 65535: 0"),
|
err: errors.New("port is not valid: 0 must be between 1 and 65535"),
|
||||||
},
|
},
|
||||||
"port is minus one": {
|
"port is minus one": {
|
||||||
line: "remote 1.2.3.4 -1",
|
line: "remote 1.2.3.4 -1",
|
||||||
err: errors.New("port is not valid: not between 1 and 65535: -1"),
|
err: errors.New("port is not valid: -1 must be between 1 and 65535"),
|
||||||
},
|
},
|
||||||
"port is over 65535": {
|
"port is over 65535": {
|
||||||
line: "remote 1.2.3.4 65536",
|
line: "remote 1.2.3.4 65536",
|
||||||
err: errors.New("port is not valid: not between 1 and 65535: 65536"),
|
err: errors.New("port is not valid: 65536 must be between 1 and 65535"),
|
||||||
},
|
},
|
||||||
"IP host and port": {
|
"IP host and port": {
|
||||||
line: "remote 1.2.3.4 8000",
|
line: "remote 1.2.3.4 8000",
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
func ExtractCert(b []byte) (certData string, err error) {
|
func ExtractCert(b []byte) (certData string, err error) {
|
||||||
certData, err = extractPEM(b, "CERTIFICATE")
|
certData, err = extractPEM(b, "CERTIFICATE")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("%w: %s", ErrExtractPEM, err)
|
return "", fmt.Errorf("cannot extract PEM data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return certData, nil
|
return certData, nil
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
package parse
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrExtractPEM = errors.New("cannot extract PEM data")
|
|
||||||
)
|
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
func ExtractPrivateKey(b []byte) (keyData string, err error) {
|
func ExtractPrivateKey(b []byte) (keyData string, err error) {
|
||||||
keyData, err = extractPEM(b, "PRIVATE KEY")
|
keyData, err = extractPEM(b, "PRIVATE KEY")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("%w: %s", ErrExtractPEM, err)
|
return "", fmt.Errorf("cannot extract PEM data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return keyData, nil
|
return keyData, nil
|
||||||
|
|||||||
@@ -27,6 +27,6 @@ func (l *Loop) firewallAllowPort(ctx context.Context) {
|
|||||||
startData := l.state.GetStartData()
|
startData := l.state.GetStartData()
|
||||||
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
|
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error("cannot allow port through firewall: " + err.Error())
|
l.logger.Error("cannot allow port: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
|
ErrVPNTypeNotSupported = errors.New("VPN type not supported for custom provider")
|
||||||
ErrExtractConnection = errors.New("cannot extract connection")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetConnection gets the connection from the OpenVPN configuration file.
|
// GetConnection gets the connection from the OpenVPN configuration file.
|
||||||
@@ -34,7 +33,7 @@ func getOpenVPNConnection(extractor extract.Interface,
|
|||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
_, connection, err = extractor.Data(*selection.OpenVPN.ConfFile)
|
_, connection, err = extractor.Data(*selection.OpenVPN.ConfFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return connection, fmt.Errorf("%w: %s", ErrExtractConnection, err)
|
return connection, fmt.Errorf("cannot extract connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connection.Port = getPort(connection.Port, selection)
|
connection.Port = getPort(connection.Port, selection)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func (p *Provider) BuildConf(connection models.Connection,
|
|||||||
settings settings.OpenVPN) (lines []string, err error) {
|
settings settings.OpenVPN) (lines []string, err error) {
|
||||||
lines, _, err = p.extractor.Data(*settings.ConfFile)
|
lines, _, err = p.extractor.Data(*settings.ConfFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrExtractData, err)
|
return nil, fmt.Errorf("failed extracting information from custom configuration file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
lines = modifyConfig(lines, connection, settings)
|
lines = modifyConfig(lines, connection, settings)
|
||||||
|
|||||||
@@ -1,23 +1,16 @@
|
|||||||
package ipvanish
|
package ipvanish
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
|
|
||||||
|
|
||||||
func (i *Ipvanish) GetConnection(selection settings.ServerSelection) (
|
func (i *Ipvanish) GetConnection(selection settings.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
const port = 443
|
const port = 443
|
||||||
const protocol = constants.UDP
|
const protocol = constants.UDP
|
||||||
if *selection.OpenVPN.TCP {
|
|
||||||
return connection, ErrProtocolUnsupported
|
|
||||||
}
|
|
||||||
|
|
||||||
servers, err := i.filterServers(selection)
|
servers, err := i.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,24 +1,16 @@
|
|||||||
package privado
|
package privado
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
|
|
||||||
|
|
||||||
func (p *Privado) GetConnection(selection settings.ServerSelection) (
|
func (p *Privado) GetConnection(selection settings.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
const port = 1194
|
const port = 1194
|
||||||
const protocol = constants.UDP
|
const protocol = constants.UDP
|
||||||
if *selection.OpenVPN.TCP {
|
|
||||||
return connection, fmt.Errorf("%w: TCP for provider Privado", ErrProtocolUnsupported)
|
|
||||||
}
|
|
||||||
|
|
||||||
servers, err := p.filterServers(selection)
|
servers, err := p.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -13,18 +12,14 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrParseCertificate = errors.New("cannot parse X509 certificate")
|
|
||||||
)
|
|
||||||
|
|
||||||
func newHTTPClient(serverName string) (client *http.Client, err error) {
|
func newHTTPClient(serverName string) (client *http.Client, err error) {
|
||||||
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong)
|
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrParseCertificate, err)
|
return nil, fmt.Errorf("cannot parse X509 certificate: %w", err)
|
||||||
}
|
}
|
||||||
certificate, err := x509.ParseCertificate(certificateBytes)
|
certificate, err := x509.ParseCertificate(certificateBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrParseCertificate, err)
|
return nil, fmt.Errorf("cannot parse X509 certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gomnd
|
//nolint:gomnd
|
||||||
|
|||||||
@@ -23,10 +23,6 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
||||||
ErrServerNameEmpty = errors.New("server name is empty")
|
ErrServerNameEmpty = errors.New("server name is empty")
|
||||||
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
|
|
||||||
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
|
|
||||||
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
|
|
||||||
ErrBindPort = errors.New("cannot bind port")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||||
@@ -53,12 +49,12 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
privateIPClient, err := newHTTPClient(serverName)
|
privateIPClient, err := newHTTPClient(serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
return 0, fmt.Errorf("cannot create custom HTTP client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
return 0, fmt.Errorf("cannot read saved port forwarded data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataFound := data.Port > 0
|
dataFound := data.Port > 0
|
||||||
@@ -79,7 +75,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|||||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||||
p.portForwardPath, p.authFilePath)
|
p.portForwardPath, p.authFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
|
return 0, fmt.Errorf("cannot refresh port forward data: %w", err)
|
||||||
}
|
}
|
||||||
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
||||||
}
|
}
|
||||||
@@ -87,7 +83,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
// First time binding
|
// First time binding
|
||||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||||
return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
|
return 0, fmt.Errorf("cannot bind port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data.Port, nil
|
return data.Port, nil
|
||||||
@@ -101,12 +97,12 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
|||||||
port uint16, gateway net.IP, serverName string) (err error) {
|
port uint16, gateway net.IP, serverName string) (err error) {
|
||||||
privateIPClient, err := newHTTPClient(serverName)
|
privateIPClient, err := newHTTPClient(serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
return fmt.Errorf("cannot create custom HTTP client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
return fmt.Errorf("cannot read saved port forwarded data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||||
@@ -128,7 +124,7 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
|||||||
case <-keepAliveTimer.C:
|
case <-keepAliveTimer.C:
|
||||||
err := bindPort(ctx, privateIPClient, gateway, data)
|
err := bindPort(ctx, privateIPClient, gateway, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
return fmt.Errorf("cannot bind port: %w", err)
|
||||||
}
|
}
|
||||||
keepAliveTimer.Reset(keepAlivePeriod)
|
keepAliveTimer.Reset(keepAlivePeriod)
|
||||||
case <-expiryTimer.C:
|
case <-expiryTimer.C:
|
||||||
@@ -138,26 +134,20 @@ func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrFetchToken = errors.New("cannot fetch token")
|
|
||||||
ErrFetchPortForwarding = errors.New("cannot fetch port forwarding data")
|
|
||||||
ErrPersistPortForwarding = errors.New("cannot persist port forwarding data")
|
|
||||||
)
|
|
||||||
|
|
||||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||||
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
|
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
|
||||||
data.Token, err = fetchToken(ctx, client, authFilePath)
|
data.Token, err = fetchToken(ctx, client, authFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
|
return data, fmt.Errorf("cannot fetch token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
|
return data, fmt.Errorf("cannot fetch port forwarding data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writePIAPortForwardData(portForwardPath, data); err != nil {
|
if err := writePIAPortForwardData(portForwardPath, data); err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
|
return data, fmt.Errorf("cannot persist port forwarding data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
@@ -242,7 +232,6 @@ func packPayload(port uint16, token string, expiration time.Time) (payload strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errGetCredentials = errors.New("cannot get username and password")
|
|
||||||
errEmptyToken = errors.New("token received is empty")
|
errEmptyToken = errors.New("token received is empty")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -250,7 +239,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|||||||
authFilePath string) (token string, err error) {
|
authFilePath string) (token string, err error) {
|
||||||
username, password, err := getOpenvpnCredentials(authFilePath)
|
username, password, err := getOpenvpnCredentials(authFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("%w: %s", errGetCredentials, err)
|
return "", fmt.Errorf("cannot get username and password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
errSubstitutions := map[string]string{
|
errSubstitutions := map[string]string{
|
||||||
@@ -284,7 +273,7 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
if err := decoder.Decode(&result); err != nil {
|
if err := decoder.Decode(&result); err != nil {
|
||||||
return "", fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
|
return "", fmt.Errorf("cannot unmarshal response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Token == "" {
|
if result.Token == "" {
|
||||||
@@ -294,7 +283,6 @@ func fetchToken(ctx context.Context, client *http.Client,
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errAuthFileRead = errors.New("cannot read OpenVPN authentication file")
|
|
||||||
errAuthFileMalformed = errors.New("authentication file is malformed")
|
errAuthFileMalformed = errors.New("authentication file is malformed")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -302,13 +290,13 @@ func getOpenvpnCredentials(authFilePath string) (
|
|||||||
username, password string, err error) {
|
username, password string, err error) {
|
||||||
file, err := os.Open(authFilePath)
|
file, err := os.Open(authFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
return "", "", fmt.Errorf("cannot read OpenVPN authentication file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authData, err := io.ReadAll(file)
|
authData, err := io.ReadAll(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = file.Close()
|
_ = file.Close()
|
||||||
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
return "", "", fmt.Errorf("authentication file is malformed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := file.Close(); err != nil {
|
if err := file.Close(); err != nil {
|
||||||
@@ -325,11 +313,6 @@ func getOpenvpnCredentials(authFilePath string) (
|
|||||||
return username, password, nil
|
return username, password, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errGetSignaturePayload = errors.New("cannot obtain signature payload")
|
|
||||||
errUnpackPayload = errors.New("cannot unpack payload data")
|
|
||||||
)
|
|
||||||
|
|
||||||
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (
|
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (
|
||||||
port uint16, signature string, expiration time.Time, err error) {
|
port uint16, signature string, expiration time.Time, err error) {
|
||||||
errSubstitutions := map[string]string{token: "<token>"}
|
errSubstitutions := map[string]string{token: "<token>"}
|
||||||
@@ -345,13 +328,13 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
|
|||||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = replaceInErr(err, errSubstitutions)
|
err = replaceInErr(err, errSubstitutions)
|
||||||
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Do(request)
|
response, err := client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = replaceInErr(err, errSubstitutions)
|
err = replaceInErr(err, errSubstitutions)
|
||||||
return 0, "", expiration, fmt.Errorf("%w: %s", errGetSignaturePayload, err)
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature payload: %w", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
@@ -366,7 +349,7 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
|
|||||||
Signature string `json:"signature"`
|
Signature string `json:"signature"`
|
||||||
}
|
}
|
||||||
if err := decoder.Decode(&data); err != nil {
|
if err := decoder.Decode(&data); err != nil {
|
||||||
return 0, "", expiration, fmt.Errorf("%w: %s", ErrUnmarshalResponse, err)
|
return 0, "", expiration, fmt.Errorf("cannot unmarshal response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.Status != "OK" {
|
if data.Status != "OK" {
|
||||||
@@ -375,21 +358,19 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway net.
|
|||||||
|
|
||||||
port, _, expiration, err = unpackPayload(data.Payload)
|
port, _, expiration, err = unpackPayload(data.Payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", expiration, fmt.Errorf("%w: %s", errUnpackPayload, err)
|
return 0, "", expiration, fmt.Errorf("cannot unpack payload data: %w", err)
|
||||||
}
|
}
|
||||||
return port, data.Signature, expiration, err
|
return port, data.Signature, expiration, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrSerializePayload = errors.New("cannot serialize payload")
|
|
||||||
ErrUnmarshalResponse = errors.New("cannot unmarshal response")
|
|
||||||
ErrBadResponse = errors.New("bad response received")
|
ErrBadResponse = errors.New("bad response received")
|
||||||
)
|
)
|
||||||
|
|
||||||
func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
||||||
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
payload, err := packPayload(data.Port, data.Token, data.Expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrSerializePayload, err)
|
return fmt.Errorf("cannot serialize payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
queryParams := make(url.Values)
|
queryParams := make(url.Values)
|
||||||
@@ -428,7 +409,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
if err := decoder.Decode(&responseData); err != nil {
|
if err := decoder.Decode(&responseData); err != nil {
|
||||||
return fmt.Errorf("%w: from %s: %s", ErrUnmarshalResponse, url.String(), err)
|
return fmt.Errorf("cannot unmarshal response: from %s: %w", url.String(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseData.Status != "OK" {
|
if responseData.Status != "OK" {
|
||||||
@@ -464,6 +445,7 @@ func makeNOKStatusError(response *http.Response, substitutions map[string]string
|
|||||||
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
|
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
|
||||||
shortenMessage = replaceInString(shortenMessage, substitutions)
|
shortenMessage = replaceInString(shortenMessage, substitutions)
|
||||||
|
|
||||||
return fmt.Errorf("%w: %s: %s: response received: %s",
|
return fmt.Errorf("%w: %s: %d %s: response received: %s",
|
||||||
ErrHTTPStatusCodeNotOK, url, response.Status, shortenMessage)
|
ErrHTTPStatusCodeNotOK, url, response.StatusCode,
|
||||||
|
response.Status, shortenMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,16 @@
|
|||||||
package vpnunlimited
|
package vpnunlimited
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
|
|
||||||
|
|
||||||
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
func (p *Provider) GetConnection(selection settings.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
const port = 1194
|
const port = 1194
|
||||||
const protocol = constants.UDP
|
const protocol = constants.UDP
|
||||||
if *selection.OpenVPN.TCP {
|
|
||||||
return connection, ErrProtocolUnsupported
|
|
||||||
}
|
|
||||||
|
|
||||||
servers, err := p.filterServers(selection)
|
servers, err := p.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,24 +1,16 @@
|
|||||||
package vyprvpn
|
package vyprvpn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrProtocolUnsupported = errors.New("network protocol is not supported")
|
|
||||||
|
|
||||||
func (v *Vyprvpn) GetConnection(selection settings.ServerSelection) (
|
func (v *Vyprvpn) GetConnection(selection settings.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
const port = 443
|
const port = 443
|
||||||
const protocol = constants.UDP
|
const protocol = constants.UDP
|
||||||
if *selection.OpenVPN.TCP {
|
|
||||||
return connection, fmt.Errorf("%w: TCP for provider VyprVPN", ErrProtocolUnsupported)
|
|
||||||
}
|
|
||||||
|
|
||||||
servers, err := v.filterServers(selection)
|
servers, err := v.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,5 +4,4 @@ import "errors"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrBadStatusCode = errors.New("bad HTTP status")
|
ErrBadStatusCode = errors.New("bad HTTP status")
|
||||||
ErrCannotReadBody = errors.New("cannot read response body")
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,12 +56,13 @@ func (f *Fetch) FetchPublicIP(ctx context.Context) (ip net.IP, err error) {
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
|
return nil, fmt.Errorf("%w from %s: %d %s", ErrBadStatusCode,
|
||||||
|
url, response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
content, err := io.ReadAll(response.Body)
|
content, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
|
return nil, fmt.Errorf("cannot ready response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := strings.ReplaceAll(string(content), "\n", "")
|
s := strings.ReplaceAll(string(content), "\n", "")
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ func Info(ctx context.Context, client *http.Client, ip net.IP) ( //nolint:interf
|
|||||||
case http.StatusTooManyRequests:
|
case http.StatusTooManyRequests:
|
||||||
return result, fmt.Errorf("%w: %s", ErrTooManyRequests, baseURL)
|
return result, fmt.Errorf("%w: %s", ErrTooManyRequests, baseURL)
|
||||||
default:
|
default:
|
||||||
return result, fmt.Errorf("%w: %d", ErrBadHTTPStatus, response.StatusCode)
|
return result, fmt.Errorf("%w: %d %s", ErrBadHTTPStatus,
|
||||||
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type DefaultRouteGetter interface {
|
|||||||
func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return "", nil, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Dst == nil {
|
if route.Dst == nil {
|
||||||
@@ -27,7 +27,7 @@ func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
|
|||||||
linkIndex := route.LinkIndex
|
linkIndex := route.LinkIndex
|
||||||
link, err := r.netLinker.LinkByIndex(linkIndex)
|
link, err := r.netLinker.LinkByIndex(linkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err)
|
return "", nil, fmt.Errorf("cannot obtain link by index: for default route at index %d: %w", linkIndex, err)
|
||||||
}
|
}
|
||||||
attributes := link.Attrs()
|
attributes := link.Attrs()
|
||||||
defaultInterface = attributes.Name
|
defaultInterface = attributes.Name
|
||||||
@@ -46,7 +46,7 @@ type DefaultIPGetter interface {
|
|||||||
func (r *Routing) DefaultIP() (ip net.IP, err error) {
|
func (r *Routing) DefaultIP() (ip net.IP, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return nil, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultLinkName := ""
|
defaultLinkName := ""
|
||||||
@@ -55,7 +55,7 @@ func (r *Routing) DefaultIP() (ip net.IP, err error) {
|
|||||||
linkIndex := route.LinkIndex
|
linkIndex := route.LinkIndex
|
||||||
link, err := r.netLinker.LinkByIndex(linkIndex)
|
link, err := r.netLinker.LinkByIndex(linkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err)
|
return nil, fmt.Errorf("cannot find link by index: for default route at index %d: %w", linkIndex, err)
|
||||||
}
|
}
|
||||||
defaultLinkName = link.Attrs().Name
|
defaultLinkName = link.Attrs().Name
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,9 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrDefaultRoute = errors.New("cannot get default route")
|
|
||||||
ErrAddInboundFromDefault = errors.New("cannot add routes for inbound traffic from default IP")
|
|
||||||
ErrDelInboundFromDefault = errors.New("cannot remove routes for inbound traffic from default IP")
|
|
||||||
ErrSubnetsOutboundSet = errors.New("cannot set outbound subnets routes")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Setuper interface {
|
type Setuper interface {
|
||||||
Setup() (err error)
|
Setup() (err error)
|
||||||
}
|
}
|
||||||
@@ -19,7 +11,7 @@ type Setuper interface {
|
|||||||
func (r *Routing) Setup() (err error) {
|
func (r *Routing) Setup() (err error) {
|
||||||
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrDefaultRoute, err)
|
return fmt.Errorf("cannot get default route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
touched := false
|
touched := false
|
||||||
@@ -35,14 +27,14 @@ func (r *Routing) Setup() (err error) {
|
|||||||
|
|
||||||
err = r.routeInboundFromDefault(defaultGateway, defaultInterfaceName)
|
err = r.routeInboundFromDefault(defaultGateway, defaultInterfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrAddInboundFromDefault, err)
|
return fmt.Errorf("cannot add routes for inbound traffic from default IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.stateMutex.RLock()
|
r.stateMutex.RLock()
|
||||||
outboundSubnets := r.outboundSubnets
|
outboundSubnets := r.outboundSubnets
|
||||||
r.stateMutex.RUnlock()
|
r.stateMutex.RUnlock()
|
||||||
if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil {
|
if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrSubnetsOutboundSet, err)
|
return fmt.Errorf("cannot set outbound subnets routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -55,16 +47,16 @@ type TearDowner interface {
|
|||||||
func (r *Routing) TearDown() error {
|
func (r *Routing) TearDown() error {
|
||||||
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrDefaultRoute, err)
|
return fmt.Errorf("cannot get default route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.unrouteInboundFromDefault(defaultGateway, defaultInterfaceName)
|
err = r.unrouteInboundFromDefault(defaultGateway, defaultInterfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrDelInboundFromDefault, err)
|
return fmt.Errorf("cannot remove routes for inbound traffic from default IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil {
|
if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrSubnetsOutboundSet, err)
|
return fmt.Errorf("cannot set outbound subnets routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -5,7 +5,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrLinkByIndex = errors.New("cannot obtain link by index")
|
|
||||||
ErrLinkDefaultNotFound = errors.New("default link not found")
|
ErrLinkDefaultNotFound = errors.New("default link not found")
|
||||||
ErrRoutesList = errors.New("cannot list routes")
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
@@ -13,19 +12,15 @@ const (
|
|||||||
inboundPriority = 100
|
inboundPriority = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errDefaultIP = errors.New("cannot get default IP address")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *Routing) routeInboundFromDefault(defaultGateway net.IP,
|
func (r *Routing) routeInboundFromDefault(defaultGateway net.IP,
|
||||||
defaultInterface string) (err error) {
|
defaultInterface string) (err error) {
|
||||||
if err := r.addRuleInboundFromDefault(inboundTable); err != nil {
|
if err := r.addRuleInboundFromDefault(inboundTable); err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRuleAdd, err)
|
return fmt.Errorf("cannot add rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
||||||
if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
|
if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRouteAdd, err)
|
return fmt.Errorf("cannot add route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -35,11 +30,11 @@ func (r *Routing) unrouteInboundFromDefault(defaultGateway net.IP,
|
|||||||
defaultInterface string) (err error) {
|
defaultInterface string) (err error) {
|
||||||
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
||||||
if err := r.deleteRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
|
if err := r.deleteRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRouteDelete, err)
|
return fmt.Errorf("cannot delete route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.delRuleInboundFromDefault(inboundTable); err != nil {
|
if err := r.delRuleInboundFromDefault(inboundTable); err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRuleDelete, err)
|
return fmt.Errorf("cannot delete rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -48,14 +43,14 @@ func (r *Routing) unrouteInboundFromDefault(defaultGateway net.IP,
|
|||||||
func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
|
func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
|
||||||
defaultIP, err := r.DefaultIP()
|
defaultIP, err := r.DefaultIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errDefaultIP, err)
|
return fmt.Errorf("cannot find default IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
|
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
|
||||||
ruleDstNet := (*net.IPNet)(nil)
|
ruleDstNet := (*net.IPNet)(nil)
|
||||||
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRuleAdd, err)
|
return fmt.Errorf("cannot add rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -64,14 +59,14 @@ func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
|
|||||||
func (r *Routing) delRuleInboundFromDefault(table int) (err error) {
|
func (r *Routing) delRuleInboundFromDefault(table int) (err error) {
|
||||||
defaultIP, err := r.DefaultIP()
|
defaultIP, err := r.DefaultIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errDefaultIP, err)
|
return fmt.Errorf("cannot find default IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
|
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
|
||||||
ruleDstNet := (*net.IPNet)(nil)
|
ruleDstNet := (*net.IPNet)(nil)
|
||||||
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRuleDelete, err)
|
return fmt.Errorf("cannot delete rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -13,18 +13,16 @@ func IPIsPrivate(ip net.IP) bool {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
errInterfaceIPNotFound = errors.New("IP address not found for interface")
|
errInterfaceIPNotFound = errors.New("IP address not found for interface")
|
||||||
errInterfaceListAddr = errors.New("cannot list interface addresses")
|
|
||||||
errInterfaceNotFound = errors.New("network interface not found")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) assignedIP(interfaceName string) (ip net.IP, err error) {
|
func (r *Routing) assignedIP(interfaceName string) (ip net.IP, err error) {
|
||||||
iface, err := net.InterfaceByName(interfaceName)
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s: %s", errInterfaceNotFound, interfaceName, err)
|
return nil, fmt.Errorf("network interface %s not found: %w", interfaceName, err)
|
||||||
}
|
}
|
||||||
addresses, err := iface.Addrs()
|
addresses, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s: %s", errInterfaceListAddr, interfaceName, err)
|
return nil, fmt.Errorf("cannot list interface %s addresses: %w", interfaceName, err)
|
||||||
}
|
}
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
switch value := address.(type) {
|
switch value := address.(type) {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrLinkList = errors.New("cannot list links")
|
|
||||||
ErrLinkLocalNotFound = errors.New("local link not found")
|
ErrLinkLocalNotFound = errors.New("local link not found")
|
||||||
ErrSubnetDefaultNotFound = errors.New("default subnet not found")
|
ErrSubnetDefaultNotFound = errors.New("default subnet not found")
|
||||||
ErrSubnetLocalNotFound = errors.New("local subnet not found")
|
ErrSubnetLocalNotFound = errors.New("local subnet not found")
|
||||||
@@ -28,7 +27,7 @@ type LocalSubnetGetter interface {
|
|||||||
func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return defaultSubnet, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return defaultSubnet, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultLinkIndex := -1
|
defaultLinkIndex := -1
|
||||||
@@ -61,7 +60,7 @@ type LocalNetworksGetter interface {
|
|||||||
func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
||||||
links, err := r.netLinker.LinkList()
|
links, err := r.netLinker.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return localNetworks, fmt.Errorf("%w: %s", ErrLinkList, err)
|
return localNetworks, fmt.Errorf("cannot list links: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
localLinks := make(map[int]struct{})
|
localLinks := make(map[int]struct{})
|
||||||
@@ -81,7 +80,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
|
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return localNetworks, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return localNetworks, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
@@ -98,7 +97,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
|
|
||||||
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return localNetworks, fmt.Errorf("%w: at index %d: %s", ErrLinkByIndex, route.LinkIndex, err)
|
return localNetworks, fmt.Errorf("cannot find link at index %d: %w", route.LinkIndex, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
localNet.InterfaceName = link.Attrs().Name
|
localNet.InterfaceName = link.Attrs().Name
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
@@ -13,10 +12,6 @@ const (
|
|||||||
outboundPriority = 99
|
outboundPriority = 99
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errAddOutboundSubnet = errors.New("cannot add outbound subnet to routes")
|
|
||||||
)
|
|
||||||
|
|
||||||
type OutboundRoutesSetter interface {
|
type OutboundRoutesSetter interface {
|
||||||
SetOutboundRoutes(outboundSubnets []net.IPNet) error
|
SetOutboundRoutes(outboundSubnets []net.IPNet) error
|
||||||
}
|
}
|
||||||
@@ -48,7 +43,7 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
|
|||||||
|
|
||||||
err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway)
|
err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errAddOutboundSubnet, err)
|
return fmt.Errorf("cannot add outbound subnet to routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -68,7 +63,7 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
|
|||||||
err = r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
err = r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
warnings = append(warnings,
|
warnings = append(warnings,
|
||||||
errRuleDelete.Error()+": for subnet "+subNet.String()+": "+err.Error())
|
"cannot delete rule: for subnet "+subNet.String()+": "+err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,16 +78,14 @@ func (r *Routing) addOutboundSubnets(subnets []net.IPNet,
|
|||||||
for i, subnet := range subnets {
|
for i, subnet := range subnets {
|
||||||
err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, outboundTable)
|
err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, outboundTable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: for subnet %s: %s",
|
return fmt.Errorf("cannot add route for subnet %s: %w", subnet, err)
|
||||||
errRouteAdd, subnet, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSrcNet := (*net.IPNet)(nil)
|
ruleSrcNet := (*net.IPNet)(nil)
|
||||||
ruleDstNet := &subnets[i]
|
ruleDstNet := &subnets[i]
|
||||||
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: for subnet %s: %s",
|
return fmt.Errorf("cannot add rule: for subnet %s: %w", subnet, err)
|
||||||
errRuleAdd, subnet, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.outboundSubnets = append(r.outboundSubnets, subnet)
|
r.outboundSubnets = append(r.outboundSubnets, subnet)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -9,12 +8,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errLinkByName = errors.New("cannot obtain link by name")
|
|
||||||
errRouteAdd = errors.New("cannot add route")
|
|
||||||
errRouteDelete = errors.New("cannot delete route")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
|
func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
|
||||||
iface string, table int) error {
|
iface string, table int) error {
|
||||||
destinationStr := destination.String()
|
destinationStr := destination.String()
|
||||||
@@ -26,7 +19,7 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
|
|||||||
|
|
||||||
link, err := r.netLinker.LinkByName(iface)
|
link, err := r.netLinker.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: interface %s: %s", errLinkByName, iface, err)
|
return fmt.Errorf("cannot find link for interface %s: %w", iface, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
@@ -36,8 +29,8 @@ func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP,
|
|||||||
Table: table,
|
Table: table,
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RouteReplace(&route); err != nil {
|
if err := r.netLinker.RouteReplace(&route); err != nil {
|
||||||
return fmt.Errorf("%w: for subnet %s at interface %s",
|
return fmt.Errorf("cannot replace route for subnet %s at interface %s: %w",
|
||||||
err, destinationStr, iface)
|
destinationStr, iface, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -54,7 +47,7 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP,
|
|||||||
|
|
||||||
link, err := r.netLinker.LinkByName(iface)
|
link, err := r.netLinker.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: for interface %s: %s", errLinkByName, iface, err)
|
return fmt.Errorf("cannot find link for interface %s: %w", iface, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
@@ -64,8 +57,8 @@ func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP,
|
|||||||
Table: table,
|
Table: table,
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RouteDel(&route); err != nil {
|
if err := r.netLinker.RouteDel(&route); err != nil {
|
||||||
return fmt.Errorf("%w: for subnet %s at interface %s",
|
return fmt.Errorf("cannot delete route: for subnet %s at interface %s: %w",
|
||||||
err, destinationStr, iface)
|
destinationStr, iface, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,19 +2,12 @@ package routing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errRulesList = errors.New("cannot list rules")
|
|
||||||
errRuleAdd = errors.New("cannot add rule")
|
|
||||||
errRuleDelete = errors.New("cannot delete rule")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
|
func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
|
||||||
const add = true
|
const add = true
|
||||||
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
|
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
|
||||||
@@ -27,7 +20,7 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
|
|||||||
|
|
||||||
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRulesList, err)
|
return fmt.Errorf("cannot list rules: %w", err)
|
||||||
}
|
}
|
||||||
for i := range existingRules {
|
for i := range existingRules {
|
||||||
if !rulesAreEqual(&existingRules[i], rule) {
|
if !rulesAreEqual(&existingRules[i], rule) {
|
||||||
@@ -37,7 +30,7 @@ func (r *Routing) addIPRule(src, dst *net.IPNet, table, priority int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.netLinker.RuleAdd(rule); err != nil {
|
if err := r.netLinker.RuleAdd(rule); err != nil {
|
||||||
return fmt.Errorf("%w: for rule: %s", err, rule)
|
return fmt.Errorf("cannot add rule %s: %w", rule, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -54,14 +47,14 @@ func (r *Routing) deleteIPRule(src, dst *net.IPNet, table, priority int) error {
|
|||||||
|
|
||||||
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errRulesList, err)
|
return fmt.Errorf("cannot list rules: %w", err)
|
||||||
}
|
}
|
||||||
for i := range existingRules {
|
for i := range existingRules {
|
||||||
if !rulesAreEqual(&existingRules[i], rule) {
|
if !rulesAreEqual(&existingRules[i], rule) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RuleDel(rule); err != nil {
|
if err := r.netLinker.RuleDel(rule); err != nil {
|
||||||
return fmt.Errorf("%w: for rule: %s", err, rule)
|
return fmt.Errorf("cannot delete rule %s: %w", rule, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func Test_Routing_addIPRule(t *testing.T) {
|
|||||||
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
|
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
|
||||||
err: errDummy,
|
err: errDummy,
|
||||||
},
|
},
|
||||||
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"),
|
err: errors.New("cannot add rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
|
||||||
},
|
},
|
||||||
"add rule success": {
|
"add rule success": {
|
||||||
src: makeIPNet(t, 1),
|
src: makeIPNet(t, 1),
|
||||||
@@ -193,7 +193,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
|
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
|
||||||
err: errDummy,
|
err: errDummy,
|
||||||
},
|
},
|
||||||
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"),
|
err: errors.New("cannot delete rule ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99: dummy error"),
|
||||||
},
|
},
|
||||||
"rule deleted": {
|
"rule deleted": {
|
||||||
src: makeIPNet(t, 1),
|
src: makeIPNet(t, 1),
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type VPNDestinationIPGetter interface {
|
|||||||
func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
|
func (r *Routing) VPNDestinationIP() (ip net.IP, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return nil, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultLinkIndex := -1
|
defaultLinkIndex := -1
|
||||||
@@ -53,12 +53,12 @@ type VPNLocalGatewayIPGetter interface {
|
|||||||
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
|
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
|
return nil, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err)
|
return nil, fmt.Errorf("cannot find link at index %d: %w", route.LinkIndex, err)
|
||||||
}
|
}
|
||||||
interfaceName := link.Attrs().Name
|
interfaceName := link.Attrs().Name
|
||||||
if interfaceName == vpnIntf &&
|
if interfaceName == vpnIntf &&
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errDecodeVersions = errors.New("cannot decode versions")
|
|
||||||
errDecodeServers = errors.New("cannot decode servers")
|
|
||||||
errDecodeProvider = errors.New("cannot decode servers for provider")
|
errDecodeProvider = errors.New("cannot decode servers for provider")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,12 +42,12 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
|
|||||||
servers models.AllServers, err error) {
|
servers models.AllServers, err error) {
|
||||||
var versions allVersions
|
var versions allVersions
|
||||||
if err := json.Unmarshal(b, &versions); err != nil {
|
if err := json.Unmarshal(b, &versions); err != nil {
|
||||||
return servers, fmt.Errorf("%w: %s", errDecodeVersions, err)
|
return servers, fmt.Errorf("cannot decode versions: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rawMessages allJSONRawMessages
|
var rawMessages allJSONRawMessages
|
||||||
if err := json.Unmarshal(b, &rawMessages); err != nil {
|
if err := json.Unmarshal(b, &rawMessages); err != nil {
|
||||||
return servers, fmt.Errorf("%w: %s", errDecodeServers, err)
|
return servers, fmt.Errorf("cannot decode servers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO simplify with generics in Go 1.18
|
// TODO simplify with generics in Go 1.18
|
||||||
|
|||||||
@@ -1,18 +1,12 @@
|
|||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrCannotReadFile = errors.New("cannot read servers from file")
|
|
||||||
ErrCannotWriteFile = errors.New("cannot write servers to file")
|
|
||||||
)
|
|
||||||
|
|
||||||
func countServers(allServers models.AllServers) int {
|
func countServers(allServers models.AllServers) int {
|
||||||
return len(allServers.Cyberghost.Servers) +
|
return len(allServers.Cyberghost.Servers) +
|
||||||
len(allServers.Expressvpn.Servers) +
|
len(allServers.Expressvpn.Servers) +
|
||||||
@@ -39,7 +33,7 @@ func countServers(allServers models.AllServers) int {
|
|||||||
func (s *Storage) SyncServers() (err error) {
|
func (s *Storage) SyncServers() (err error) {
|
||||||
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers)
|
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrCannotReadFile, err)
|
return fmt.Errorf("cannot read servers from file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hardcodedCount := countServers(s.hardcodedServers)
|
hardcodedCount := countServers(s.hardcodedServers)
|
||||||
@@ -64,7 +58,7 @@ func (s *Storage) SyncServers() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := flushToFile(s.filepath, s.mergedServers); err != nil {
|
if err := flushToFile(s.filepath, s.mergedServers); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
|
return fmt.Errorf("cannot write servers to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,24 +12,21 @@ type Checker interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrTUNNotAvailable = errors.New("TUN device is not available")
|
|
||||||
ErrTUNStat = errors.New("cannot stat TUN file")
|
|
||||||
ErrTUNInfo = errors.New("cannot get syscall stat info of TUN file")
|
ErrTUNInfo = errors.New("cannot get syscall stat info of TUN file")
|
||||||
ErrTUNBadRdev = errors.New("TUN file has an unexpected rdev")
|
ErrTUNBadRdev = errors.New("TUN file has an unexpected rdev")
|
||||||
ErrTUNClose = errors.New("cannot close TUN device")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check checks the tunnel device specified by path is present and accessible.
|
// Check checks the tunnel device specified by path is present and accessible.
|
||||||
func (t *Tun) Check(path string) error {
|
func (t *Tun) Check(path string) error {
|
||||||
f, err := os.OpenFile(path, os.O_RDWR, 0)
|
f, err := os.OpenFile(path, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrTUNNotAvailable, err)
|
return fmt.Errorf("TUN device is not available: %w", err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
info, err := f.Stat()
|
info, err := f.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrTUNStat, err)
|
return fmt.Errorf("cannot stat TUN file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sys, ok := info.Sys().(*syscall.Stat_t)
|
sys, ok := info.Sys().(*syscall.Stat_t)
|
||||||
@@ -44,7 +41,7 @@ func (t *Tun) Check(path string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrTUNClose, err)
|
return fmt.Errorf("cannot close TUN device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package tun
|
package tun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -13,12 +12,6 @@ type Creator interface {
|
|||||||
Create(path string) error
|
Create(path string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrMknod = errors.New("cannot create TUN device file node")
|
|
||||||
ErrUnixOpen = errors.New("cannot Unix Open TUN device file")
|
|
||||||
ErrSetNonBlock = errors.New("cannot set non block to TUN device file descriptor")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create creates a TUN device at the path specified.
|
// Create creates a TUN device at the path specified.
|
||||||
func (t *Tun) Create(path string) error {
|
func (t *Tun) Create(path string) error {
|
||||||
parentDir := filepath.Dir(path)
|
parentDir := filepath.Dir(path)
|
||||||
@@ -33,18 +26,18 @@ func (t *Tun) Create(path string) error {
|
|||||||
dev := unix.Mkdev(major, minor)
|
dev := unix.Mkdev(major, minor)
|
||||||
err := t.mknod(path, unix.S_IFCHR, int(dev))
|
err := t.mknod(path, unix.S_IFCHR, int(dev))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrMknod, err)
|
return fmt.Errorf("cannot create TUN device file node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := unix.Open(path, 0, 0)
|
fd, err := unix.Open(path, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrUnixOpen, err)
|
return fmt.Errorf("cannot Unix Open TUN device file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const nonBlocking = true
|
const nonBlocking = true
|
||||||
err = unix.SetNonblock(fd, nonBlocking)
|
err = unix.SetNonblock(fd, nonBlocking)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", ErrSetNonBlock, err)
|
return fmt.Errorf("cannot set non block to TUN device file descriptor: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -9,11 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errBuildRequest = errors.New("cannot build HTTP request")
|
|
||||||
errDoRequest = errors.New("failed doing HTTP request")
|
|
||||||
errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
errHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||||
errUnmarshalResponseBody = errors.New("failed unmarshaling response body")
|
|
||||||
errCloseBody = errors.New("failed closing HTTP body")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type apiData struct {
|
type apiData struct {
|
||||||
@@ -40,12 +36,12 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
|
|
||||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", errBuildRequest, err)
|
return data, err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Do(request)
|
response, err := client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", errDoRequest, err)
|
return data, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
@@ -57,11 +53,11 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
if err := decoder.Decode(&data); err != nil {
|
if err := decoder.Decode(&data); err != nil {
|
||||||
_ = response.Body.Close()
|
_ = response.Body.Close()
|
||||||
return data, fmt.Errorf("%w: %s", errUnmarshalResponseBody, err)
|
return data, fmt.Errorf("failed unmarshaling response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := response.Body.Close(); err != nil {
|
if err := response.Body.Close(); err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", errCloseBody, err)
|
return data, fmt.Errorf("cannot close response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrFetchAPI = errors.New("failed fetching API")
|
|
||||||
ErrNotEnoughServers = errors.New("not enough servers found")
|
ErrNotEnoughServers = errors.New("not enough servers found")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +22,7 @@ func GetServers(ctx context.Context, client *http.Client,
|
|||||||
servers []models.IvpnServer, warnings []string, err error) {
|
servers []models.IvpnServer, warnings []string, err error) {
|
||||||
data, err := fetchAPI(ctx, client)
|
data, err := fetchAPI(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrFetchAPI, err)
|
return nil, nil, fmt.Errorf("failed fetching API: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts := make([]string, 0, len(data.Servers))
|
hosts := make([]string, 0, len(data.Servers))
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrHTTPStatusCodeNotOK, response.Status)
|
return nil, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
|
||||||
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||||
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type serverData struct {
|
type serverData struct {
|
||||||
@@ -44,7 +43,7 @@ func fetchAPI(ctx context.Context, client *http.Client) (data []serverData, err
|
|||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
if err := decoder.Decode(&data); err != nil {
|
if err := decoder.Decode(&data); err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err)
|
return nil, fmt.Errorf("failed unmarshaling response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := response.Body.Close(); err != nil {
|
if err := response.Body.Close(); err != nil {
|
||||||
|
|||||||
@@ -51,7 +51,8 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return data, fmt.Errorf("%w: %s", ErrHTTPStatusCodeNotOK, response.Status)
|
return data, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
|
||||||
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := io.ReadAll(response.Body)
|
b, err := io.ReadAll(response.Body)
|
||||||
|
|||||||
@@ -44,8 +44,7 @@ func parseFilename(fileName string) (
|
|||||||
parts := strings.Split(s, "-")
|
parts := strings.Split(s, "-")
|
||||||
const minParts = 2
|
const minParts = 2
|
||||||
if len(parts) < minParts {
|
if len(parts) < minParts {
|
||||||
return "", "", fmt.Errorf("%w: %s",
|
return "", "", fmt.Errorf("%w: %s", errNotEnoughParts, fileName)
|
||||||
errNotEnoughParts, fileName)
|
|
||||||
}
|
}
|
||||||
countryCode, city = parts[0], parts[1]
|
countryCode, city = parts[0], parts[1]
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||||
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type apiData struct {
|
type apiData struct {
|
||||||
@@ -49,12 +48,13 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return data, fmt.Errorf("%w: %s", ErrHTTPStatusCodeNotOK, response.Status)
|
return data, fmt.Errorf("%w: %d %s", ErrHTTPStatusCodeNotOK,
|
||||||
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
if err := decoder.Decode(&data); err != nil {
|
if err := decoder.Decode(&data); err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err)
|
return data, fmt.Errorf("failed unmarshaling response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := response.Body.Close(); err != nil {
|
if err := response.Body.Close(); err != nil {
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ func addServersFromAPI(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||||
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type serverData struct {
|
type serverData struct {
|
||||||
@@ -66,7 +65,7 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
if err := decoder.Decode(&servers); err != nil {
|
if err := decoder.Decode(&servers); err != nil {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err)
|
return nil, fmt.Errorf("failed unmarshaling response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := response.Body.Close(); err != nil {
|
if err := response.Body.Close(); err != nil {
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrGetZip = errors.New("cannot get OpenVPN ZIP file")
|
|
||||||
ErrGetAPI = errors.New("cannot fetch server information from API")
|
|
||||||
ErrNotEnoughServers = errors.New("not enough servers found")
|
ErrNotEnoughServers = errors.New("not enough servers found")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,12 +24,12 @@ func GetServers(ctx context.Context, unzipper unzip.Unzipper,
|
|||||||
|
|
||||||
err = addServersFromAPI(ctx, client, hts)
|
err = addServersFromAPI(ctx, client, hts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrGetAPI, err)
|
return nil, nil, fmt.Errorf("cannot fetch server information from API: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
warnings, err = addOpenVPNServersFromZip(ctx, unzipper, hts)
|
warnings, err = addOpenVPNServersFromZip(ctx, unzipper, hts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrGetZip, err)
|
return nil, nil, fmt.Errorf("cannot get OpenVPN ZIP file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
getRemainingServers(hts)
|
getRemainingServers(hts)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
ErrHTTPStatusCodeNotOK = errors.New("HTTP status code not OK")
|
||||||
ErrUnmarshalResponseBody = errors.New("failed unmarshaling response body")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type apiData struct {
|
type apiData struct {
|
||||||
@@ -63,7 +62,7 @@ func fetchAPI(ctx context.Context, client *http.Client) (
|
|||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
if err := decoder.Decode(&data); err != nil {
|
if err := decoder.Decode(&data); err != nil {
|
||||||
return data, fmt.Errorf("%w: %s", ErrUnmarshalResponseBody, err)
|
return data, fmt.Errorf("failed unmarshaling response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ func (r *repeat) Resolve(ctx context.Context, host string, settings RepeatSettin
|
|||||||
var (
|
var (
|
||||||
ErrMaxNoNew = errors.New("reached the maximum number of no new update")
|
ErrMaxNoNew = errors.New("reached the maximum number of no new update")
|
||||||
ErrMaxFails = errors.New("reached the maximum number of consecutive failures")
|
ErrMaxFails = errors.New("reached the maximum number of consecutive failures")
|
||||||
ErrTimeout = errors.New("reached the timeout")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string,
|
func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string,
|
||||||
@@ -120,7 +119,7 @@ func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string,
|
|||||||
return noNewCounter, failCounter, err
|
return noNewCounter, failCounter, err
|
||||||
}
|
}
|
||||||
return noNewCounter, failCounter,
|
return noNewCounter, failCounter,
|
||||||
fmt.Errorf("%w: %s", ErrTimeout, timedCtx.Err())
|
fmt.Errorf("reached the timeout: %w", timedCtx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ func (u *unzipper) FetchAndExtract(ctx context.Context, url string) (
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("%w: %s for %s", ErrHTTPStatusCodeNotOK, response.Status, url)
|
return nil, fmt.Errorf("%w: %s: %d %s", ErrHTTPStatusCodeNotOK,
|
||||||
|
url, response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := io.ReadAll(response.Body)
|
b, err := io.ReadAll(response.Body)
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ func getGithubReleases(ctx context.Context, client *http.Client) (releases []git
|
|||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != http.StatusOK {
|
if response.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("%w: %s", errHTTPStatusCode, response.Status)
|
return nil, fmt.Errorf("%w: %d %s", errHTTPStatusCode,
|
||||||
|
response.StatusCode, response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder := json.NewDecoder(response.Body)
|
decoder := json.NewDecoder(response.Body)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package vpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
@@ -12,14 +11,6 @@ import (
|
|||||||
"github.com/qdm12/golibs/command"
|
"github.com/qdm12/golibs/command"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errServerConn = errors.New("failed finding a valid server connection")
|
|
||||||
errBuildConfig = errors.New("failed building configuration")
|
|
||||||
errWriteConfig = errors.New("failed writing configuration to file")
|
|
||||||
errWriteAuth = errors.New("failed writing auth to file")
|
|
||||||
errFirewall = errors.New("failed allowing VPN connection through firewall")
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupOpenVPN sets OpenVPN up using the configurators and settings given.
|
// setupOpenVPN sets OpenVPN up using the configurators and settings given.
|
||||||
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
||||||
func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
|
func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
|
||||||
@@ -28,27 +19,27 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
|
|||||||
runner vpnRunner, serverName string, err error) {
|
runner vpnRunner, serverName string, err error) {
|
||||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
|
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errServerConn, err)
|
return nil, "", fmt.Errorf("failed finding a valid server connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
lines, err := providerConf.BuildConf(connection, settings.OpenVPN)
|
lines, err := providerConf.BuildConf(connection, settings.OpenVPN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
|
return nil, "", fmt.Errorf("failed building configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := openvpnConf.WriteConfig(lines); err != nil {
|
if err := openvpnConf.WriteConfig(lines); err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errWriteConfig, err)
|
return nil, "", fmt.Errorf("failed writing configuration to file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.OpenVPN.User != "" {
|
if settings.OpenVPN.User != "" {
|
||||||
err := openvpnConf.WriteAuthFile(settings.OpenVPN.User, settings.OpenVPN.Password)
|
err := openvpnConf.WriteAuthFile(settings.OpenVPN.User, settings.OpenVPN.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errWriteAuth, err)
|
return nil, "", fmt.Errorf("failed writing auth to file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
|
return nil, "", fmt.Errorf("failed allowing VPN connection through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
||||||
|
|||||||
@@ -2,18 +2,12 @@ package vpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/portforward"
|
"github.com/qdm12/gluetun/internal/portforward"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errObtainVPNLocalGateway = errors.New("cannot obtain VPN local gateway IP")
|
|
||||||
errStartPortForwarding = errors.New("cannot start port forwarding")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
|
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
|
||||||
if !data.portForwarding {
|
if !data.portForwarding {
|
||||||
return nil
|
return nil
|
||||||
@@ -22,7 +16,7 @@ func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err
|
|||||||
// only used for PIA for now
|
// only used for PIA for now
|
||||||
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
|
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: for interface %s: %s", errObtainVPNLocalGateway, data.vpnIntf, err)
|
return fmt.Errorf("cannot obtain VPN local gateway IP for interface %s: %w", data.vpnIntf, err)
|
||||||
}
|
}
|
||||||
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
||||||
|
|
||||||
@@ -34,7 +28,7 @@ func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err
|
|||||||
}
|
}
|
||||||
_, err = l.portForward.Start(ctx, pfData)
|
_, err = l.portForward.Start(ctx, pfData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errStartPortForwarding, err)
|
return fmt.Errorf("cannot start port forwarding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package vpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
@@ -13,11 +12,6 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/wireguard"
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errGetServer = errors.New("failed finding a VPN server")
|
|
||||||
errCreateWireguard = errors.New("failed creating Wireguard")
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupWireguard sets Wireguard up using the configurators and settings given.
|
// setupWireguard sets Wireguard up using the configurators and settings given.
|
||||||
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
// It returns a serverName for port forwarding (PIA) and an error if it fails.
|
||||||
func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
|
func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
|
||||||
@@ -26,7 +20,7 @@ func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
|
|||||||
wireguarder wireguard.Wireguarder, serverName string, err error) {
|
wireguarder wireguard.Wireguarder, serverName string, err error) {
|
||||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
|
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errGetServer, err)
|
return nil, "", fmt.Errorf("failed finding a VPN server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard)
|
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard)
|
||||||
@@ -37,12 +31,12 @@ func setupWireguard(ctx context.Context, netlinker netlink.NetLinker,
|
|||||||
|
|
||||||
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errCreateWireguard, err)
|
return nil, "", fmt.Errorf("failed creating Wireguard: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
|
return nil, "", fmt.Errorf("failed setting firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return wireguarder, connection.Hostname, nil
|
return wireguarder, connection.Hostname, nil
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
@@ -9,20 +8,15 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errMakeConfig = errors.New("cannot make device configuration")
|
|
||||||
errConfigureDevice = errors.New("cannot configure device")
|
|
||||||
)
|
|
||||||
|
|
||||||
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
func configureDevice(client *wgctrl.Client, settings Settings) (err error) {
|
||||||
deviceConfig, err := makeDeviceConfig(settings)
|
deviceConfig, err := makeDeviceConfig(settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errMakeConfig, err)
|
return fmt.Errorf("cannot make device configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
|
err = client.ConfigureDevice(settings.InterfaceName, deviceConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %s", errConfigureDevice, err)
|
return fmt.Errorf("cannot configure device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,27 +1,21 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
errLinkList = errors.New("cannot list links")
|
|
||||||
errRouteList = errors.New("cannot list routes")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (w *Wireguard) isIPv6Supported() (supported bool, err error) {
|
func (w *Wireguard) isIPv6Supported() (supported bool, err error) {
|
||||||
links, err := w.netlink.LinkList()
|
links, err := w.netlink.LinkList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("%w: %s", errLinkList, err)
|
return false, fmt.Errorf("cannot list links: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
routes, err := w.netlink.RouteList(link, netlink.FAMILY_V6)
|
routes, err := w.netlink.RouteList(link, netlink.FAMILY_V6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("%w: %s", errRouteList, err)
|
return false, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(routes) > 0 {
|
if len(routes) > 0 {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
|
|||||||
|
|
||||||
err = w.netlink.RouteAdd(route)
|
err = w.netlink.RouteAdd(route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: when adding route: %s", err, route)
|
return fmt.Errorf("cannot add route %s: %w", route, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
|||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
},
|
},
|
||||||
routeAddErr: errDummy,
|
routeAddErr: errDummy,
|
||||||
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820 Realm: 0}"), //nolint:lll
|
err: errors.New("cannot add route {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820 Realm: 0}: dummy"), //nolint:lll
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,13 +14,13 @@ func (w *Wireguard) addRule(rulePriority, firewallMark int) (
|
|||||||
rule.Mark = firewallMark
|
rule.Mark = firewallMark
|
||||||
rule.Table = firewallMark
|
rule.Table = firewallMark
|
||||||
if err := w.netlink.RuleAdd(rule); err != nil {
|
if err := w.netlink.RuleAdd(rule); err != nil {
|
||||||
return nil, fmt.Errorf("%w: when adding rule: %s", err, rule)
|
return nil, fmt.Errorf("cannot add rule %s: %w", rule, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup = func() error {
|
cleanup = func() error {
|
||||||
err := w.netlink.RuleDel(rule)
|
err := w.netlink.RuleDel(rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: when deleting rule: %s", err, rule)
|
return fmt.Errorf("cannot delete rule %s: %w", rule, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func Test_Wireguard_addRule(t *testing.T) {
|
|||||||
SuppressPrefixlen: -1,
|
SuppressPrefixlen: -1,
|
||||||
},
|
},
|
||||||
ruleAddErr: errDummy,
|
ruleAddErr: errDummy,
|
||||||
err: errors.New("dummy: when adding rule: ip rule 987: from all to all table 456"),
|
err: errors.New("cannot add rule ip rule 987: from all to all table 456: dummy"),
|
||||||
},
|
},
|
||||||
"rule delete error": {
|
"rule delete error": {
|
||||||
expectedRule: &netlink.Rule{
|
expectedRule: &netlink.Rule{
|
||||||
@@ -66,7 +66,7 @@ func Test_Wireguard_addRule(t *testing.T) {
|
|||||||
SuppressPrefixlen: -1,
|
SuppressPrefixlen: -1,
|
||||||
},
|
},
|
||||||
ruleDelErr: errDummy,
|
ruleDelErr: errDummy,
|
||||||
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from all to all table 456"),
|
cleanupErr: errors.New("cannot delete rule ip rule 987: from all to all table 456: dummy"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user