chore(portforward): remove PIA dependency on storage package

This commit is contained in:
Quentin McGaw
2024-05-02 09:17:30 +00:00
parent e0a977cf83
commit 6dd27e53d4
18 changed files with 63 additions and 98 deletions

View File

@@ -16,17 +16,20 @@ type Connection struct {
// Hostname is used for IPVanish, IVPN, Privado // Hostname is used for IPVanish, IVPN, Privado
// and Windscribe for TLS verification. // and Windscribe for TLS verification.
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
// ServerName is used for PIA for port forwarding
ServerName string `json:"server_name,omitempty"`
// PubKey is the public key of the VPN server, // PubKey is the public key of the VPN server,
// used only for Wireguard. // used only for Wireguard.
PubKey string `json:"pubkey"` PubKey string `json:"pubkey"`
// ServerName is used for PIA for port forwarding
ServerName string `json:"server_name,omitempty"`
// PortForward is used for PIA for port forwarding
PortForward bool `json:"port_forward"`
} }
func (c *Connection) Equal(other Connection) bool { func (c *Connection) Equal(other Connection) bool {
return c.IP.Compare(other.IP) == 0 && c.Port == other.Port && return c.IP.Compare(other.IP) == 0 && c.Port == other.Port &&
c.Protocol == other.Protocol && c.Hostname == other.Hostname && c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
c.ServerName == other.ServerName && c.PubKey == other.PubKey c.PubKey == other.PubKey && c.ServerName == other.ServerName &&
c.PortForward == other.PortForward
} }
// UpdateEmptyWith updates each field of the connection where the // UpdateEmptyWith updates each field of the connection where the

View File

@@ -9,12 +9,13 @@ import (
) )
type Settings struct { type Settings struct {
Enabled *bool Enabled *bool
PortForwarder PortForwarder PortForwarder PortForwarder
Filepath string Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA ServerName string // needed for PIA
ListeningPort uint16 CanPortForward bool // needed for PIA
ListeningPort uint16
} }
func (s Settings) Copy() (copied Settings) { func (s Settings) Copy() (copied Settings) {
@@ -23,6 +24,7 @@ func (s Settings) Copy() (copied Settings) {
copied.Filepath = s.Filepath copied.Filepath = s.Filepath
copied.Interface = s.Interface copied.Interface = s.Interface
copied.ServerName = s.ServerName copied.ServerName = s.ServerName
copied.CanPortForward = s.CanPortForward
copied.ListeningPort = s.ListeningPort copied.ListeningPort = s.ListeningPort
return copied return copied
} }
@@ -33,6 +35,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath) s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface) s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName) s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort) s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort)
} }

View File

@@ -23,10 +23,11 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
} }
obj := utils.PortForwardObjects{ obj := utils.PortForwardObjects{
Logger: s.logger, Logger: s.logger,
Gateway: gateway, Gateway: gateway,
Client: s.client, Client: s.client,
ServerName: s.settings.ServerName, ServerName: s.settings.ServerName,
CanPortForward: s.settings.CanPortForward,
} }
port, err := s.settings.PortForwarder.PortForward(ctx, obj) port, err := s.settings.PortForwarder.PortForward(ctx, obj)
if err != nil { if err != nil {

View File

@@ -92,21 +92,6 @@ func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1)
} }
// GetServerByName mocks base method.
func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1)
ret0, _ := ret[0].(models.Server)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetServerByName indicates an expected call of GetServerByName.
func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1)
}
// MockUnzipper is a mock of Unzipper interface. // MockUnzipper is a mock of Unzipper interface.
type MockUnzipper struct { type MockUnzipper struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@@ -8,5 +8,4 @@ import (
type Storage interface { type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) ( FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error) servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
} }

View File

@@ -44,6 +44,7 @@ func getOpenVPNConnection(extractor Extractor,
// Set the server name for PIA port forwarding code used // Set the server name for PIA port forwarding code used
// together with the custom provider. // together with the custom provider.
connection.ServerName = selection.Names[0] connection.ServerName = selection.Names[0]
connection.PortForward = true
} }
return connection, nil return connection, nil
@@ -62,6 +63,7 @@ func getWireguardConnection(selection settings.ServerSelection) (
// Set the server name for PIA port forwarding code used // Set the server name for PIA port forwarding code used
// together with the custom provider. // together with the custom provider.
connection.ServerName = selection.Names[0] connection.ServerName = selection.Names[0]
connection.PortForward = true
} }
return connection return connection
} }

View File

@@ -16,7 +16,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/format" "github.com/qdm12/golibs/format"
) )
@@ -37,16 +36,10 @@ func (p *Provider) PortForward(ctx context.Context,
serverName := objects.ServerName serverName := objects.ServerName
server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
if !ok {
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
}
logger := objects.Logger logger := objects.Logger
if !server.PortForward { if !objects.CanPortForward {
logger.Error("The server " + serverName + logger.Error("The server " + serverName + " does not support port forwarding")
" (region " + server.Region + ") does not support port forwarding")
return 0, nil return 0, nil
} }

View File

@@ -43,7 +43,6 @@ type Providers struct {
type Storage interface { type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) ( FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error) servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
} }
type Extractor interface { type Extractor interface {

View File

@@ -60,13 +60,14 @@ func GetConnection(provider string,
} }
connection := models.Connection{ connection := models.Connection{
Type: selection.VPN, Type: selection.VPN,
IP: ip, IP: ip,
Port: port, Port: port,
Protocol: protocol, Protocol: protocol,
Hostname: hostname, Hostname: hostname,
ServerName: server.ServerName, ServerName: server.ServerName,
PubKey: server.WgPubKey, // Wireguard PortForward: server.PortForward,
PubKey: server.WgPubKey, // Wireguard
} }
connections = append(connections, connection) connections = append(connections, connection)
} }

View File

@@ -15,11 +15,10 @@ type PortForwardObjects struct {
Gateway netip.Addr Gateway netip.Addr
// Client is used to query the VPN gateway for Private Internet Access. // Client is used to query the VPN gateway for Private Internet Access.
Client *http.Client Client *http.Client
// ServerName is used by Private Internet Access for port forwarding, // ServerName is used by Private Internet Access for port forwarding.
// and to look up the server data from storage.
// TODO use server data directly to remove storage dependency for port
// forwarding implementation.
ServerName string ServerName string
// CanPortForward is used by Private Internet Access for port forwarding.
CanPortForward bool
} }
type Routing interface { type Routing interface {

View File

@@ -33,29 +33,6 @@ func (s *Storage) SetServers(provider string, servers []models.Server) (err erro
return nil return nil
} }
// GetServerByName returns the server for the given provider
// and server name. It returns `ok` as false if the server is
// not found. The returned server is also deep copied so it is
// safe for mutation and/or thread safe use.
func (s *Storage) GetServerByName(provider, name string) (
server models.Server, ok bool) {
if provider == providers.Custom {
return server, false
}
s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()
serversObject := s.getMergedServersObject(provider)
for _, server := range serversObject.Servers {
if server.ServerName == name {
return copyServer(server), true
}
}
return server, false
}
// GetServersCount returns the number of servers for the provider given. // GetServersCount returns the number of servers for the provider given.
func (s *Storage) GetServersCount(provider string) (count int) { func (s *Storage) GetServersCount(provider string) (count int) {
if provider == providers.Custom { if provider == providers.Custom {

View File

@@ -18,7 +18,6 @@ type Storage interface {
ServersAreEqual(provider string, servers []models.Server) (equal bool) ServersAreEqual(provider string, servers []models.Server) (equal bool)
// Extra methods to match the provider.New storage interface // Extra methods to match the provider.New storage interface
FilterServers(provider string, selection settings.ServerSelection) (filtered []models.Server, err error) FilterServers(provider string, selection settings.ServerSelection) (filtered []models.Server, err error)
GetServerByName(provider string, name string) (server models.Server, ok bool)
} }
type Unzipper interface { type Unzipper interface {

View File

@@ -53,7 +53,6 @@ type PortForwarder interface {
type Storage interface { type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error) FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
} }
type NetLinker interface { type NetLinker interface {

View File

@@ -15,37 +15,38 @@ import (
func setupOpenVPN(ctx context.Context, fw Firewall, func setupOpenVPN(ctx context.Context, fw Firewall,
openvpnConf OpenVPN, providerConf provider.Provider, openvpnConf OpenVPN, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, starter command.Starter, settings settings.VPN, ipv6Supported bool, starter command.Starter,
logger openvpn.Logger) (runner *openvpn.Runner, serverName string, err error) { logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
canPortForward bool, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("finding a valid server connection: %w", err) return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
} }
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported) lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
if err := openvpnConf.WriteConfig(lines); err != nil { if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, "", fmt.Errorf("writing configuration to file: %w", err) return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
} }
if *settings.OpenVPN.User != "" { if *settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password) err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("writing auth to file: %w", err) return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
} }
} }
if *settings.OpenVPN.KeyPassphrase != "" { if *settings.OpenVPN.KeyPassphrase != "" {
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase) err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("writing askpass file: %w", err) return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
} }
} }
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil { if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
return nil, "", fmt.Errorf("allowing VPN connection through firewall: %w", err) return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
} }
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger) runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
return runner, connection.ServerName, nil return runner, connection.ServerName, connection.PortForward, nil
} }

View File

@@ -26,9 +26,10 @@ func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
partialUpdate := portforward.Settings{ partialUpdate := portforward.Settings{
VPNIsUp: ptrTo(true), VPNIsUp: ptrTo(true),
Service: service.Settings{ Service: service.Settings{
PortForwarder: data.portForwarder, PortForwarder: data.portForwarder,
Interface: data.vpnIntf, Interface: data.vpnIntf,
ServerName: data.serverName, ServerName: data.serverName,
CanPortForward: data.canPortForward,
}, },
} }
return l.portForward.UpdateWith(partialUpdate) return l.portForward.UpdateWith(partialUpdate)

View File

@@ -29,15 +29,16 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{}) Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
} }
var serverName, vpnInterface string var serverName, vpnInterface string
var canPortForward bool
var err error var err error
subLogger := l.logger.New(log.SetComponent(settings.Type)) subLogger := l.logger.New(log.SetComponent(settings.Type))
if settings.Type == vpn.OpenVPN { if settings.Type == vpn.OpenVPN {
vpnInterface = settings.OpenVPN.Interface vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw, vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard } else { // Wireguard
vpnInterface = settings.Wireguard.Interface vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw, vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger) providerConf, settings, l.ipv6Supported, subLogger)
} }
if err != nil { if err != nil {
@@ -45,9 +46,10 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue continue
} }
tunnelUpData := tunnelUpData{ tunnelUpData := tunnelUpData{
serverName: serverName, serverName: serverName,
portForwarder: portForwarder, canPortForward: canPortForward,
vpnIntf: vpnInterface, portForwarder: portForwarder,
vpnIntf: vpnInterface,
} }
openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) openvpnCtx, openvpnCancel := context.WithCancel(context.Background())

View File

@@ -9,9 +9,10 @@ import (
type tunnelUpData struct { type tunnelUpData struct {
// Port forwarding // Port forwarding
vpnIntf string vpnIntf string
serverName string serverName string // used for PIA
portForwarder PortForwarder canPortForward bool // used for PIA
portForwarder PortForwarder
} }
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {

View File

@@ -16,10 +16,10 @@ import (
func setupWireguard(ctx context.Context, netlinker NetLinker, func setupWireguard(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider, fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, err error) { wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("finding a VPN server: %w", err) return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
} }
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported) wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
@@ -30,13 +30,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("creating Wireguard: %w", err) return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
} }
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("setting firewall: %w", err) return nil, "", false, fmt.Errorf("setting firewall: %w", err)
} }
return wireguarder, connection.ServerName, nil return wireguarder, connection.ServerName, connection.PortForward, nil
} }