chore(portforward): move vpn gateway obtention within port forwarding service

This commit is contained in:
Quentin McGaw
2023-09-23 11:46:14 +00:00
parent 71201411f4
commit 0406de399d
14 changed files with 135 additions and 92 deletions

View File

@@ -22,30 +22,33 @@ import (
)
var (
ErrServerNameNotFound = errors.New("server name not found in servers")
ErrGatewayIPIsNotValid = errors.New("gateway IP address is not valid")
ErrServerNameEmpty = errors.New("server name is empty")
ErrServerNameNotFound = errors.New("server name not found in servers")
)
// PortForward obtains a VPN server side port forwarded from PIA.
func (p *Provider) PortForward(ctx context.Context, client *http.Client,
logger utils.Logger, gateway netip.Addr, serverName string) (
port uint16, err error) {
func (p *Provider) PortForward(ctx context.Context,
objects utils.PortForwardObjects) (port uint16, err error) {
switch {
case objects.ServerName == "":
panic("server name cannot be empty")
case !objects.Gateway.IsValid():
panic("gateway is not set")
}
serverName := objects.ServerName
server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
if !ok {
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
}
logger := objects.Logger
if !server.PortForward {
logger.Error("The server " + serverName +
" (region " + server.Region + ") does not support port forwarding")
return 0, nil
}
if !gateway.IsValid() {
return 0, fmt.Errorf("%w: %s", ErrGatewayIPIsNotValid, gateway)
} else if serverName == "" {
return 0, ErrServerNameEmpty
}
privateIPClient, err := newHTTPClient(serverName)
if err != nil {
@@ -70,7 +73,8 @@ func (p *Provider) PortForward(ctx context.Context, client *http.Client,
}
if !dataFound || expired {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
return 0, fmt.Errorf("refreshing port forward data: %w", err)
@@ -80,7 +84,7 @@ func (p *Provider) PortForward(ctx context.Context, client *http.Client,
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
// First time binding
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
return 0, fmt.Errorf("binding port: %w", err)
}
@@ -91,9 +95,16 @@ var (
ErrPortForwardedExpired = errors.New("port forwarded data expired")
)
func (p *Provider) KeepPortForward(ctx context.Context, _ uint16,
gateway netip.Addr, serverName string, _ utils.Logger) (err error) {
privateIPClient, err := newHTTPClient(serverName)
func (p *Provider) KeepPortForward(ctx context.Context,
objects utils.PortForwardObjects) (err error) {
switch {
case objects.ServerName == "":
panic("server name cannot be empty")
case !objects.Gateway.IsValid():
panic("gateway is not set")
}
privateIPClient, err := newHTTPClient(objects.ServerName)
if err != nil {
return fmt.Errorf("creating custom HTTP client: %w", err)
}
@@ -120,7 +131,7 @@ func (p *Provider) KeepPortForward(ctx context.Context, _ uint16,
}
return ctx.Err()
case <-keepAliveTimer.C:
err := bindPort(ctx, privateIPClient, gateway, data)
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
if err != nil {
return fmt.Errorf("binding port: %w", err)
}