diff --git a/internal/models/connection.go b/internal/models/connection.go index 60689af6..c85b05e1 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -16,17 +16,20 @@ type Connection struct { // Hostname is used for IPVanish, IVPN, Privado // and Windscribe for TLS verification. Hostname string `json:"hostname"` - // ServerName is used for PIA for port forwarding - ServerName string `json:"server_name,omitempty"` // PubKey is the public key of the VPN server, // used only for Wireguard. PubKey string `json:"pubkey"` + // ServerName is used for PIA for port forwarding + ServerName string `json:"server_name,omitempty"` + // PortForward is used for PIA for port forwarding + PortForward bool `json:"port_forward"` } func (c *Connection) Equal(other Connection) bool { return c.IP.Compare(other.IP) == 0 && c.Port == other.Port && c.Protocol == other.Protocol && c.Hostname == other.Hostname && - c.ServerName == other.ServerName && c.PubKey == other.PubKey + c.PubKey == other.PubKey && c.ServerName == other.ServerName && + c.PortForward == other.PortForward } // UpdateEmptyWith updates each field of the connection where the diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index 8deee301..0e35117d 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -9,12 +9,13 @@ import ( ) type Settings struct { - Enabled *bool - PortForwarder PortForwarder - Filepath string - Interface string // needed for PIA and ProtonVPN, tun0 for example - ServerName string // needed for PIA - ListeningPort uint16 + Enabled *bool + PortForwarder PortForwarder + Filepath string + Interface string // needed for PIA and ProtonVPN, tun0 for example + ServerName string // needed for PIA + CanPortForward bool // needed for PIA + ListeningPort uint16 } func (s Settings) Copy() (copied Settings) { @@ -23,6 +24,7 @@ func (s Settings) Copy() (copied Settings) { copied.Filepath = s.Filepath copied.Interface = s.Interface copied.ServerName = s.ServerName + copied.CanPortForward = s.CanPortForward copied.ListeningPort = s.ListeningPort return copied } @@ -33,6 +35,7 @@ func (s *Settings) OverrideWith(update Settings) { s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath) s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface) s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName) + s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward) s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort) } diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index 0ad55b74..99f514ec 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -23,10 +23,11 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) } obj := utils.PortForwardObjects{ - Logger: s.logger, - Gateway: gateway, - Client: s.client, - ServerName: s.settings.ServerName, + Logger: s.logger, + Gateway: gateway, + Client: s.client, + ServerName: s.settings.ServerName, + CanPortForward: s.settings.CanPortForward, } port, err := s.settings.PortForwarder.PortForward(ctx, obj) if err != nil { diff --git a/internal/provider/common/mocks.go b/internal/provider/common/mocks.go index affeb7b1..d7ddaa1a 100644 --- a/internal/provider/common/mocks.go +++ b/internal/provider/common/mocks.go @@ -92,21 +92,6 @@ func (mr *MockStorageMockRecorder) FilterServers(arg0, arg1 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterServers", reflect.TypeOf((*MockStorage)(nil).FilterServers), arg0, arg1) } -// GetServerByName mocks base method. -func (m *MockStorage) GetServerByName(arg0, arg1 string) (models.Server, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServerByName", arg0, arg1) - ret0, _ := ret[0].(models.Server) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// GetServerByName indicates an expected call of GetServerByName. -func (mr *MockStorageMockRecorder) GetServerByName(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByName", reflect.TypeOf((*MockStorage)(nil).GetServerByName), arg0, arg1) -} - // MockUnzipper is a mock of Unzipper interface. type MockUnzipper struct { ctrl *gomock.Controller diff --git a/internal/provider/common/storage.go b/internal/provider/common/storage.go index 37517e0d..eb720a12 100644 --- a/internal/provider/common/storage.go +++ b/internal/provider/common/storage.go @@ -8,5 +8,4 @@ import ( type Storage interface { FilterServers(provider string, selection settings.ServerSelection) ( servers []models.Server, err error) - GetServerByName(provider, name string) (server models.Server, ok bool) } diff --git a/internal/provider/custom/connection.go b/internal/provider/custom/connection.go index 088d3bed..38a16a2d 100644 --- a/internal/provider/custom/connection.go +++ b/internal/provider/custom/connection.go @@ -44,6 +44,7 @@ func getOpenVPNConnection(extractor Extractor, // Set the server name for PIA port forwarding code used // together with the custom provider. connection.ServerName = selection.Names[0] + connection.PortForward = true } return connection, nil @@ -62,6 +63,7 @@ func getWireguardConnection(selection settings.ServerSelection) ( // Set the server name for PIA port forwarding code used // together with the custom provider. connection.ServerName = selection.Names[0] + connection.PortForward = true } return connection } diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 563c0453..fefc6eb6 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -16,7 +16,6 @@ import ( "strings" "time" - "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/golibs/format" ) @@ -37,16 +36,10 @@ func (p *Provider) PortForward(ctx context.Context, serverName := objects.ServerName - server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName) - if !ok { - return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName) - } - logger := objects.Logger - if !server.PortForward { - logger.Error("The server " + serverName + - " (region " + server.Region + ") does not support port forwarding") + if !objects.CanPortForward { + logger.Error("The server " + serverName + " does not support port forwarding") return 0, nil } diff --git a/internal/provider/providers.go b/internal/provider/providers.go index 53643671..37c9e735 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -43,7 +43,6 @@ type Providers struct { type Storage interface { FilterServers(provider string, selection settings.ServerSelection) ( servers []models.Server, err error) - GetServerByName(provider, name string) (server models.Server, ok bool) } type Extractor interface { diff --git a/internal/provider/utils/connection.go b/internal/provider/utils/connection.go index c3aeae1b..70bfed60 100644 --- a/internal/provider/utils/connection.go +++ b/internal/provider/utils/connection.go @@ -60,13 +60,14 @@ func GetConnection(provider string, } connection := models.Connection{ - Type: selection.VPN, - IP: ip, - Port: port, - Protocol: protocol, - Hostname: hostname, - ServerName: server.ServerName, - PubKey: server.WgPubKey, // Wireguard + Type: selection.VPN, + IP: ip, + Port: port, + Protocol: protocol, + Hostname: hostname, + ServerName: server.ServerName, + PortForward: server.PortForward, + PubKey: server.WgPubKey, // Wireguard } connections = append(connections, connection) } diff --git a/internal/provider/utils/portforward.go b/internal/provider/utils/portforward.go index baae073d..8a5430a3 100644 --- a/internal/provider/utils/portforward.go +++ b/internal/provider/utils/portforward.go @@ -15,11 +15,10 @@ type PortForwardObjects struct { Gateway netip.Addr // Client is used to query the VPN gateway for Private Internet Access. Client *http.Client - // ServerName is used by Private Internet Access for port forwarding, - // and to look up the server data from storage. - // TODO use server data directly to remove storage dependency for port - // forwarding implementation. + // ServerName is used by Private Internet Access for port forwarding. ServerName string + // CanPortForward is used by Private Internet Access for port forwarding. + CanPortForward bool } type Routing interface { diff --git a/internal/storage/servers.go b/internal/storage/servers.go index f8f9b02c..88aa4245 100644 --- a/internal/storage/servers.go +++ b/internal/storage/servers.go @@ -33,29 +33,6 @@ func (s *Storage) SetServers(provider string, servers []models.Server) (err erro return nil } -// GetServerByName returns the server for the given provider -// and server name. It returns `ok` as false if the server is -// not found. The returned server is also deep copied so it is -// safe for mutation and/or thread safe use. -func (s *Storage) GetServerByName(provider, name string) ( - server models.Server, ok bool) { - if provider == providers.Custom { - return server, false - } - - s.mergedMutex.RLock() - defer s.mergedMutex.RUnlock() - - serversObject := s.getMergedServersObject(provider) - for _, server := range serversObject.Servers { - if server.ServerName == name { - return copyServer(server), true - } - } - - return server, false -} - // GetServersCount returns the number of servers for the provider given. func (s *Storage) GetServersCount(provider string) (count int) { if provider == providers.Custom { diff --git a/internal/updater/interfaces.go b/internal/updater/interfaces.go index 1c50bb87..0f80613c 100644 --- a/internal/updater/interfaces.go +++ b/internal/updater/interfaces.go @@ -18,7 +18,6 @@ type Storage interface { ServersAreEqual(provider string, servers []models.Server) (equal bool) // Extra methods to match the provider.New storage interface FilterServers(provider string, selection settings.ServerSelection) (filtered []models.Server, err error) - GetServerByName(provider string, name string) (server models.Server, ok bool) } type Unzipper interface { diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 5ca8eca9..6fe8b4f6 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -53,7 +53,6 @@ type PortForwarder interface { type Storage interface { FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error) - GetServerByName(provider, name string) (server models.Server, ok bool) } type NetLinker interface { diff --git a/internal/vpn/openvpn.go b/internal/vpn/openvpn.go index dabb1eac..0a248b1d 100644 --- a/internal/vpn/openvpn.go +++ b/internal/vpn/openvpn.go @@ -15,37 +15,38 @@ import ( func setupOpenVPN(ctx context.Context, fw Firewall, openvpnConf OpenVPN, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, starter command.Starter, - logger openvpn.Logger) (runner *openvpn.Runner, serverName string, err error) { + logger openvpn.Logger) (runner *openvpn.Runner, serverName string, + canPortForward bool, err error) { connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) if err != nil { - return nil, "", fmt.Errorf("finding a valid server connection: %w", err) + return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err) } lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported) if err := openvpnConf.WriteConfig(lines); err != nil { - return nil, "", fmt.Errorf("writing configuration to file: %w", err) + return nil, "", false, fmt.Errorf("writing configuration to file: %w", err) } if *settings.OpenVPN.User != "" { err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password) if err != nil { - return nil, "", fmt.Errorf("writing auth to file: %w", err) + return nil, "", false, fmt.Errorf("writing auth to file: %w", err) } } if *settings.OpenVPN.KeyPassphrase != "" { err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase) if err != nil { - return nil, "", fmt.Errorf("writing askpass file: %w", err) + return nil, "", false, fmt.Errorf("writing askpass file: %w", err) } } if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil { - return nil, "", fmt.Errorf("allowing VPN connection through firewall: %w", err) + return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err) } runner = openvpn.NewRunner(settings.OpenVPN, starter, logger) - return runner, connection.ServerName, nil + return runner, connection.ServerName, connection.PortForward, nil } diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index f6f6b117..f7eb92f2 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -26,9 +26,10 @@ func (l *Loop) startPortForwarding(data tunnelUpData) (err error) { partialUpdate := portforward.Settings{ VPNIsUp: ptrTo(true), Service: service.Settings{ - PortForwarder: data.portForwarder, - Interface: data.vpnIntf, - ServerName: data.serverName, + PortForwarder: data.portForwarder, + Interface: data.vpnIntf, + ServerName: data.serverName, + CanPortForward: data.canPortForward, }, } return l.portForward.UpdateWith(partialUpdate) diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 749dd92b..8926ac6d 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -29,15 +29,16 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{}) } var serverName, vpnInterface string + var canPortForward bool var err error subLogger := l.logger.New(log.SetComponent(settings.Type)) if settings.Type == vpn.OpenVPN { vpnInterface = settings.OpenVPN.Interface - vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw, + vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw, l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) } else { // Wireguard vpnInterface = settings.Wireguard.Interface - vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw, + vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.ipv6Supported, subLogger) } if err != nil { @@ -45,9 +46,10 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ - serverName: serverName, - portForwarder: portForwarder, - vpnIntf: vpnInterface, + serverName: serverName, + canPortForward: canPortForward, + portForwarder: portForwarder, + vpnIntf: vpnInterface, } openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 74064f5b..58a88551 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -9,9 +9,10 @@ import ( type tunnelUpData struct { // Port forwarding - vpnIntf string - serverName string - portForwarder PortForwarder + vpnIntf string + serverName string // used for PIA + canPortForward bool // used for PIA + portForwarder PortForwarder } func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { diff --git a/internal/vpn/wireguard.go b/internal/vpn/wireguard.go index 0899c626..9c1c6f58 100644 --- a/internal/vpn/wireguard.go +++ b/internal/vpn/wireguard.go @@ -16,10 +16,10 @@ import ( func setupWireguard(ctx context.Context, netlinker NetLinker, fw Firewall, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( - wireguarder *wireguard.Wireguard, serverName string, err error) { + wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error) { connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) if err != nil { - return nil, "", fmt.Errorf("finding a VPN server: %w", err) + return nil, "", false, fmt.Errorf("finding a VPN server: %w", err) } wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported) @@ -30,13 +30,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker, wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) if err != nil { - return nil, "", fmt.Errorf("creating Wireguard: %w", err) + return nil, "", false, fmt.Errorf("creating Wireguard: %w", err) } err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) if err != nil { - return nil, "", fmt.Errorf("setting firewall: %w", err) + return nil, "", false, fmt.Errorf("setting firewall: %w", err) } - return wireguarder, connection.ServerName, nil + return wireguarder, connection.ServerName, connection.PortForward, nil }