chore(portforward): move vpn gateway obtention within port forwarding service
This commit is contained in:
@@ -22,30 +22,33 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrServerNameNotFound = errors.New("server name not found in servers")
|
||||
ErrGatewayIPIsNotValid = errors.New("gateway IP address is not valid")
|
||||
ErrServerNameEmpty = errors.New("server name is empty")
|
||||
ErrServerNameNotFound = errors.New("server name not found in servers")
|
||||
)
|
||||
|
||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||
func (p *Provider) PortForward(ctx context.Context, client *http.Client,
|
||||
logger utils.Logger, gateway netip.Addr, serverName string) (
|
||||
port uint16, err error) {
|
||||
func (p *Provider) PortForward(ctx context.Context,
|
||||
objects utils.PortForwardObjects) (port uint16, err error) {
|
||||
switch {
|
||||
case objects.ServerName == "":
|
||||
panic("server name cannot be empty")
|
||||
case !objects.Gateway.IsValid():
|
||||
panic("gateway is not set")
|
||||
}
|
||||
|
||||
serverName := objects.ServerName
|
||||
|
||||
server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
|
||||
}
|
||||
|
||||
logger := objects.Logger
|
||||
|
||||
if !server.PortForward {
|
||||
logger.Error("The server " + serverName +
|
||||
" (region " + server.Region + ") does not support port forwarding")
|
||||
return 0, nil
|
||||
}
|
||||
if !gateway.IsValid() {
|
||||
return 0, fmt.Errorf("%w: %s", ErrGatewayIPIsNotValid, gateway)
|
||||
} else if serverName == "" {
|
||||
return 0, ErrServerNameEmpty
|
||||
}
|
||||
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
if err != nil {
|
||||
@@ -70,7 +73,8 @@ func (p *Provider) PortForward(ctx context.Context, client *http.Client,
|
||||
}
|
||||
|
||||
if !dataFound || expired {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
client := objects.Client
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("refreshing port forward data: %w", err)
|
||||
@@ -80,7 +84,7 @@ func (p *Provider) PortForward(ctx context.Context, client *http.Client,
|
||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
|
||||
return 0, fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
|
||||
@@ -91,9 +95,16 @@ var (
|
||||
ErrPortForwardedExpired = errors.New("port forwarded data expired")
|
||||
)
|
||||
|
||||
func (p *Provider) KeepPortForward(ctx context.Context, _ uint16,
|
||||
gateway netip.Addr, serverName string, _ utils.Logger) (err error) {
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
func (p *Provider) KeepPortForward(ctx context.Context,
|
||||
objects utils.PortForwardObjects) (err error) {
|
||||
switch {
|
||||
case objects.ServerName == "":
|
||||
panic("server name cannot be empty")
|
||||
case !objects.Gateway.IsValid():
|
||||
panic("gateway is not set")
|
||||
}
|
||||
|
||||
privateIPClient, err := newHTTPClient(objects.ServerName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating custom HTTP client: %w", err)
|
||||
}
|
||||
@@ -120,7 +131,7 @@ func (p *Provider) KeepPortForward(ctx context.Context, _ uint16,
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-keepAliveTimer.C:
|
||||
err := bindPort(ctx, privateIPClient, gateway, data)
|
||||
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user