chore(portforward): remove PIA dependency on storage package
This commit is contained in:
@@ -16,17 +16,20 @@ type Connection struct {
|
||||
// Hostname is used for IPVanish, IVPN, Privado
|
||||
// and Windscribe for TLS verification.
|
||||
Hostname string `json:"hostname"`
|
||||
// ServerName is used for PIA for port forwarding
|
||||
ServerName string `json:"server_name,omitempty"`
|
||||
// PubKey is the public key of the VPN server,
|
||||
// used only for Wireguard.
|
||||
PubKey string `json:"pubkey"`
|
||||
// ServerName is used for PIA for port forwarding
|
||||
ServerName string `json:"server_name,omitempty"`
|
||||
// PortForward is used for PIA for port forwarding
|
||||
PortForward bool `json:"port_forward"`
|
||||
}
|
||||
|
||||
func (c *Connection) Equal(other Connection) bool {
|
||||
return c.IP.Compare(other.IP) == 0 && c.Port == other.Port &&
|
||||
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
|
||||
c.ServerName == other.ServerName && c.PubKey == other.PubKey
|
||||
c.PubKey == other.PubKey && c.ServerName == other.ServerName &&
|
||||
c.PortForward == other.PortForward
|
||||
}
|
||||
|
||||
// UpdateEmptyWith updates each field of the connection where the
|
||||
|
||||
@@ -14,6 +14,7 @@ type Settings struct {
|
||||
Filepath string
|
||||
Interface string // needed for PIA and ProtonVPN, tun0 for example
|
||||
ServerName string // needed for PIA
|
||||
CanPortForward bool // needed for PIA
|
||||
ListeningPort uint16
|
||||
}
|
||||
|
||||
@@ -23,6 +24,7 @@ func (s Settings) Copy() (copied Settings) {
|
||||
copied.Filepath = s.Filepath
|
||||
copied.Interface = s.Interface
|
||||
copied.ServerName = s.ServerName
|
||||
copied.CanPortForward = s.CanPortForward
|
||||
copied.ListeningPort = s.ListeningPort
|
||||
return copied
|
||||
}
|
||||
@@ -33,6 +35,7 @@ func (s *Settings) OverrideWith(update Settings) {
|
||||
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
|
||||
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
|
||||
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
|
||||
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
|
||||
s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
Gateway: gateway,
|
||||
Client: s.client,
|
||||
ServerName: s.settings.ServerName,
|
||||
CanPortForward: s.settings.CanPortForward,
|
||||
}
|
||||
port, err := s.settings.PortForwarder.PortForward(ctx, obj)
|
||||
if err != nil {
|
||||
|
||||
@@ -92,21 +92,6 @@ func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetServerByName mocks base method.
|
||||
func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1)
|
||||
ret0, _ := ret[0].(models.Server)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServerByName indicates an expected call of GetServerByName.
|
||||
func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockUnzipper is a mock of Unzipper interface.
|
||||
type MockUnzipper struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@@ -8,5 +8,4 @@ import (
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error)
|
||||
GetServerByName(provider, name string) (server models.Server, ok bool)
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ func getOpenVPNConnection(extractor Extractor,
|
||||
// Set the server name for PIA port forwarding code used
|
||||
// together with the custom provider.
|
||||
connection.ServerName = selection.Names[0]
|
||||
connection.PortForward = true
|
||||
}
|
||||
|
||||
return connection, nil
|
||||
@@ -62,6 +63,7 @@ func getWireguardConnection(selection settings.ServerSelection) (
|
||||
// Set the server name for PIA port forwarding code used
|
||||
// together with the custom provider.
|
||||
connection.ServerName = selection.Names[0]
|
||||
connection.PortForward = true
|
||||
}
|
||||
return connection
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/golibs/format"
|
||||
)
|
||||
@@ -37,16 +36,10 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
|
||||
serverName := objects.ServerName
|
||||
|
||||
server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
|
||||
}
|
||||
|
||||
logger := objects.Logger
|
||||
|
||||
if !server.PortForward {
|
||||
logger.Error("The server " + serverName +
|
||||
" (region " + server.Region + ") does not support port forwarding")
|
||||
if !objects.CanPortForward {
|
||||
logger.Error("The server " + serverName + " does not support port forwarding")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ type Providers struct {
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error)
|
||||
GetServerByName(provider, name string) (server models.Server, ok bool)
|
||||
}
|
||||
|
||||
type Extractor interface {
|
||||
|
||||
@@ -66,6 +66,7 @@ func GetConnection(provider string,
|
||||
Protocol: protocol,
|
||||
Hostname: hostname,
|
||||
ServerName: server.ServerName,
|
||||
PortForward: server.PortForward,
|
||||
PubKey: server.WgPubKey, // Wireguard
|
||||
}
|
||||
connections = append(connections, connection)
|
||||
|
||||
@@ -15,11 +15,10 @@ type PortForwardObjects struct {
|
||||
Gateway netip.Addr
|
||||
// Client is used to query the VPN gateway for Private Internet Access.
|
||||
Client *http.Client
|
||||
// ServerName is used by Private Internet Access for port forwarding,
|
||||
// and to look up the server data from storage.
|
||||
// TODO use server data directly to remove storage dependency for port
|
||||
// forwarding implementation.
|
||||
// ServerName is used by Private Internet Access for port forwarding.
|
||||
ServerName string
|
||||
// CanPortForward is used by Private Internet Access for port forwarding.
|
||||
CanPortForward bool
|
||||
}
|
||||
|
||||
type Routing interface {
|
||||
|
||||
@@ -33,29 +33,6 @@ func (s *Storage) SetServers(provider string, servers []models.Server) (err erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerByName returns the server for the given provider
|
||||
// and server name. It returns `ok` as false if the server is
|
||||
// not found. The returned server is also deep copied so it is
|
||||
// safe for mutation and/or thread safe use.
|
||||
func (s *Storage) GetServerByName(provider, name string) (
|
||||
server models.Server, ok bool) {
|
||||
if provider == providers.Custom {
|
||||
return server, false
|
||||
}
|
||||
|
||||
s.mergedMutex.RLock()
|
||||
defer s.mergedMutex.RUnlock()
|
||||
|
||||
serversObject := s.getMergedServersObject(provider)
|
||||
for _, server := range serversObject.Servers {
|
||||
if server.ServerName == name {
|
||||
return copyServer(server), true
|
||||
}
|
||||
}
|
||||
|
||||
return server, false
|
||||
}
|
||||
|
||||
// GetServersCount returns the number of servers for the provider given.
|
||||
func (s *Storage) GetServersCount(provider string) (count int) {
|
||||
if provider == providers.Custom {
|
||||
|
||||
@@ -18,7 +18,6 @@ type Storage interface {
|
||||
ServersAreEqual(provider string, servers []models.Server) (equal bool)
|
||||
// Extra methods to match the provider.New storage interface
|
||||
FilterServers(provider string, selection settings.ServerSelection) (filtered []models.Server, err error)
|
||||
GetServerByName(provider string, name string) (server models.Server, ok bool)
|
||||
}
|
||||
|
||||
type Unzipper interface {
|
||||
|
||||
@@ -53,7 +53,6 @@ type PortForwarder interface {
|
||||
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error)
|
||||
GetServerByName(provider, name string) (server models.Server, ok bool)
|
||||
}
|
||||
|
||||
type NetLinker interface {
|
||||
|
||||
@@ -15,37 +15,38 @@ import (
|
||||
func setupOpenVPN(ctx context.Context, fw Firewall,
|
||||
openvpnConf OpenVPN, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, starter command.Starter,
|
||||
logger openvpn.Logger) (runner *openvpn.Runner, serverName string, err error) {
|
||||
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
|
||||
canPortForward bool, err error) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("finding a valid server connection: %w", err)
|
||||
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
|
||||
}
|
||||
|
||||
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
|
||||
|
||||
if err := openvpnConf.WriteConfig(lines); err != nil {
|
||||
return nil, "", fmt.Errorf("writing configuration to file: %w", err)
|
||||
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
|
||||
}
|
||||
|
||||
if *settings.OpenVPN.User != "" {
|
||||
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("writing auth to file: %w", err)
|
||||
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if *settings.OpenVPN.KeyPassphrase != "" {
|
||||
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("writing askpass file: %w", err)
|
||||
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
||||
return nil, "", fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||
}
|
||||
|
||||
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
||||
|
||||
return runner, connection.ServerName, nil
|
||||
return runner, connection.ServerName, connection.PortForward, nil
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
|
||||
PortForwarder: data.portForwarder,
|
||||
Interface: data.vpnIntf,
|
||||
ServerName: data.serverName,
|
||||
CanPortForward: data.canPortForward,
|
||||
},
|
||||
}
|
||||
return l.portForward.UpdateWith(partialUpdate)
|
||||
|
||||
@@ -29,15 +29,16 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
|
||||
}
|
||||
var serverName, vpnInterface string
|
||||
var canPortForward bool
|
||||
var err error
|
||||
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
||||
if settings.Type == vpn.OpenVPN {
|
||||
vpnInterface = settings.OpenVPN.Interface
|
||||
vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw,
|
||||
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
|
||||
} else { // Wireguard
|
||||
vpnInterface = settings.Wireguard.Interface
|
||||
vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
providerConf, settings, l.ipv6Supported, subLogger)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -46,6 +47,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
serverName: serverName,
|
||||
canPortForward: canPortForward,
|
||||
portForwarder: portForwarder,
|
||||
vpnIntf: vpnInterface,
|
||||
}
|
||||
|
||||
@@ -10,7 +10,8 @@ import (
|
||||
type tunnelUpData struct {
|
||||
// Port forwarding
|
||||
vpnIntf string
|
||||
serverName string
|
||||
serverName string // used for PIA
|
||||
canPortForward bool // used for PIA
|
||||
portForwarder PortForwarder
|
||||
}
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
fw Firewall, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||
wireguarder *wireguard.Wireguard, serverName string, err error) {
|
||||
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("finding a VPN server: %w", err)
|
||||
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
|
||||
}
|
||||
|
||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
|
||||
@@ -30,13 +30,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
|
||||
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating Wireguard: %w", err)
|
||||
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
|
||||
}
|
||||
|
||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("setting firewall: %w", err)
|
||||
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
|
||||
}
|
||||
|
||||
return wireguarder, connection.ServerName, nil
|
||||
return wireguarder, connection.ServerName, connection.PortForward, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user