chore(port-forward): support multiple port forwarded

This commit is contained in:
Quentin McGaw
2024-07-28 19:49:45 +00:00
parent 4c47b6f142
commit 8c730a6e4a
16 changed files with 147 additions and 57 deletions

View File

@@ -8,7 +8,7 @@ import (
type Service interface { type Service interface {
Start(ctx context.Context) (runError <-chan error, err error) Start(ctx context.Context) (runError <-chan error, err error)
Stop() (err error) Stop() (err error)
GetPortForwarded() (port uint16) GetPortsForwarded() (ports []uint16)
} }
type Routing interface { type Routing interface {

View File

@@ -150,11 +150,11 @@ func (l *Loop) Stop() (err error) {
return nil return nil
} }
func (l *Loop) GetPortForwarded() (port uint16) { func (l *Loop) GetPortsForwarded() (ports []uint16) {
if l.service == nil { if l.service == nil {
return 0 return nil
} }
return l.service.GetPortForwarded() return l.service.GetPortsForwarded()
} }
func ptrTo[T any](value T) *T { func ptrTo[T any](value T) *T {

View File

@@ -3,13 +3,20 @@ package service
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
) )
func (s *Service) writePortForwardedFile(port uint16) (err error) { func (s *Service) writePortForwardedFile(ports []uint16) (err error) {
portStrings := make([]string, len(ports))
for i, port := range ports {
portStrings[i] = fmt.Sprint(int(port))
}
fileData := []byte(strings.Join(portStrings, "\n"))
filepath := s.settings.Filepath filepath := s.settings.Filepath
s.logger.Info("writing port file " + filepath) s.logger.Info("writing port file " + filepath)
const perms = os.FileMode(0644) const perms = os.FileMode(0644)
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms) err = os.WriteFile(filepath, fileData, perms)
if err != nil { if err != nil {
return fmt.Errorf("writing file: %w", err) return fmt.Errorf("writing file: %w", err)
} }

View File

@@ -0,0 +1,22 @@
package service
import (
"fmt"
"strings"
)
func portsToString(ports []uint16) (s string) {
switch len(ports) {
case 0:
return "no port forwarded"
case 1:
return "port forwarded is " + fmt.Sprint(int(ports[0]))
default:
portStrings := make([]string, len(ports))
for i, port := range ports {
portStrings[i] = fmt.Sprint(int(port))
}
return "ports forwarded are " + strings.Join(portStrings[:len(portStrings)-1], ", ") +
" and " + portStrings[len(portStrings)-1]
}
}

View File

@@ -0,0 +1,43 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_portsToString(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ports []uint16
s string
}{
"no_port": {
s: "no port forwarded",
},
"one_port": {
ports: []uint16{123},
s: "port forwarded is 123",
},
"two_ports": {
ports: []uint16{123, 456},
s: "ports forwarded are 123 and 456",
},
"three_ports": {
ports: []uint16{123, 456, 789},
s: "ports forwarded are 123, 456 and 789",
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
s := portsToString(testCase.ports)
assert.Equal(t, testCase.s, s)
})
}
}

View File

@@ -28,6 +28,6 @@ type Logger interface {
type PortForwarder interface { type PortForwarder interface {
Name() string Name() string
PortForward(ctx context.Context, objects utils.PortForwardObjects) ( PortForward(ctx context.Context, objects utils.PortForwardObjects) (
port uint16, err error) ports []uint16, err error)
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
} }

View File

@@ -9,7 +9,7 @@ import (
type Service struct { type Service struct {
// State // State
portMutex sync.RWMutex portMutex sync.RWMutex
port uint16 ports []uint16
// Fixed parameters // Fixed parameters
settings Settings settings Settings
puid int puid int
@@ -40,8 +40,10 @@ func New(settings Settings, routing Routing, client *http.Client,
} }
} }
func (s *Service) GetPortForwarded() (port uint16) { func (s *Service) GetPortsForwarded() (ports []uint16) {
s.portMutex.RLock() s.portMutex.RLock()
defer s.portMutex.RUnlock() defer s.portMutex.RUnlock()
return s.port ports = make([]uint16, len(s.ports))
copy(ports, s.ports)
return ports
} }

View File

@@ -31,33 +31,35 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
Username: s.settings.Username, Username: s.settings.Username,
Password: s.settings.Password, Password: s.settings.Password,
} }
port, err := s.settings.PortForwarder.PortForward(ctx, obj) ports, err := s.settings.PortForwarder.PortForward(ctx, obj)
if err != nil { if err != nil {
return nil, fmt.Errorf("port forwarding for the first time: %w", err) return nil, fmt.Errorf("port forwarding for the first time: %w", err)
} }
s.logger.Info("port forwarded is " + fmt.Sprint(int(port))) s.logger.Info(portsToString(ports))
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) for _, port := range ports {
if err != nil { err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
return nil, fmt.Errorf("allowing port in firewall: %w", err)
}
if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil { if err != nil {
return nil, fmt.Errorf("redirecting port in firewall: %w", err) return nil, fmt.Errorf("allowing port in firewall: %w", err)
}
if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil {
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
}
} }
} }
err = s.writePortForwardedFile(port) err = s.writePortForwardedFile(ports)
if err != nil { if err != nil {
_ = s.cleanup() _ = s.cleanup()
return nil, fmt.Errorf("writing port file: %w", err) return nil, fmt.Errorf("writing port file: %w", err)
} }
s.portMutex.Lock() s.portMutex.Lock()
s.port = port s.ports = ports
s.portMutex.Unlock() s.portMutex.Unlock()
keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) keepPortCtx, keepPortCancel := context.WithCancel(context.Background())

View File

@@ -11,7 +11,7 @@ func (s *Service) Stop() (err error) {
defer s.startStopMutex.Unlock() defer s.startStopMutex.Unlock()
s.portMutex.RLock() s.portMutex.RLock()
serviceNotRunning := s.port == 0 serviceNotRunning := len(s.ports) == 0
s.portMutex.RUnlock() s.portMutex.RUnlock()
if serviceNotRunning { if serviceNotRunning {
// TODO replace with goservices.ErrAlreadyStopped // TODO replace with goservices.ErrAlreadyStopped
@@ -30,21 +30,23 @@ func (s *Service) cleanup() (err error) {
s.portMutex.Lock() s.portMutex.Lock()
defer s.portMutex.Unlock() defer s.portMutex.Unlock()
err = s.portAllower.RemoveAllowedPort(context.Background(), s.port) for _, port := range s.ports {
if err != nil { err = s.portAllower.RemoveAllowedPort(context.Background(), port)
return fmt.Errorf("blocking previous port in firewall: %w", err)
}
if s.settings.ListeningPort != 0 {
ctx := context.Background()
const listeningPort = 0 // 0 to clear the redirection
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, s.port, listeningPort)
if err != nil { if err != nil {
return fmt.Errorf("removing previous port redirection in firewall: %w", err) return fmt.Errorf("blocking previous port in firewall: %w", err)
}
if s.settings.ListeningPort != 0 {
ctx := context.Background()
const listeningPort = 0 // 0 to clear the redirection
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, listeningPort)
if err != nil {
return fmt.Errorf("removing previous port redirection in firewall: %w", err)
}
} }
} }
s.port = 0 s.ports = nil
filepath := s.settings.Filepath filepath := s.settings.Filepath
s.logger.Info("removing port file " + filepath) s.logger.Info("removing port file " + filepath)

View File

@@ -26,7 +26,7 @@ var (
// PortForward obtains a VPN server side port forwarded from PIA. // PortForward obtains a VPN server side port forwarded from PIA.
func (p *Provider) PortForward(ctx context.Context, func (p *Provider) PortForward(ctx context.Context,
objects utils.PortForwardObjects) (port uint16, err error) { objects utils.PortForwardObjects) (ports []uint16, err error) {
switch { switch {
case objects.ServerName == "": case objects.ServerName == "":
panic("server name cannot be empty") panic("server name cannot be empty")
@@ -43,17 +43,17 @@ func (p *Provider) PortForward(ctx context.Context,
logger := objects.Logger logger := objects.Logger
if !objects.CanPortForward { if !objects.CanPortForward {
return 0, fmt.Errorf("%w: for server %s", ErrServerNameNotFound, serverName) return nil, fmt.Errorf("%w: for server %s", ErrServerNameNotFound, serverName)
} }
privateIPClient, err := newHTTPClient(serverName) privateIPClient, err := newHTTPClient(serverName)
if err != nil { if err != nil {
return 0, fmt.Errorf("creating custom HTTP client: %w", err) return nil, fmt.Errorf("creating custom HTTP client: %w", err)
} }
data, err := readPIAPortForwardData(p.portForwardPath) data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil { if err != nil {
return 0, fmt.Errorf("reading saved port forwarded data: %w", err) return nil, fmt.Errorf("reading saved port forwarded data: %w", err)
} }
dataFound := data.Port > 0 dataFound := data.Port > 0
@@ -73,7 +73,7 @@ func (p *Provider) PortForward(ctx context.Context,
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway, data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
p.portForwardPath, objects.Username, objects.Password) p.portForwardPath, objects.Username, objects.Password)
if err != nil { if err != nil {
return 0, fmt.Errorf("refreshing port forward data: %w", err) return nil, fmt.Errorf("refreshing port forward data: %w", err)
} }
durationToExpiration = data.Expiration.Sub(p.timeNow()) durationToExpiration = data.Expiration.Sub(p.timeNow())
} }
@@ -81,10 +81,10 @@ func (p *Provider) PortForward(ctx context.Context,
// First time binding // First time binding
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil { if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
return 0, fmt.Errorf("binding port: %w", err) return nil, fmt.Errorf("binding port: %w", err)
} }
return data.Port, nil return []uint16{data.Port}, nil
} }
var ( var (

View File

@@ -13,7 +13,7 @@ import (
// PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. // PortForward obtains a VPN server side port forwarded from ProtonVPN gateway.
func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) (
port uint16, err error) { ports []uint16, err error) {
client := natpmp.New() client := natpmp.New()
_, externalIPv4Address, err := client.ExternalAddress(ctx, _, externalIPv4Address, err := client.ExternalAddress(ctx,
objects.Gateway) objects.Gateway)
@@ -21,7 +21,7 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
if strings.HasSuffix(err.Error(), "connection refused") { if strings.HasSuffix(err.Error(), "connection refused") {
err = fmt.Errorf("%w - make sure you have +pmp at the end of your OpenVPN username", err) err = fmt.Errorf("%w - make sure you have +pmp at the end of your OpenVPN username", err)
} }
return 0, fmt.Errorf("getting external IPv4 address: %w", err) return nil, fmt.Errorf("getting external IPv4 address: %w", err)
} }
logger := objects.Logger logger := objects.Logger
@@ -34,7 +34,7 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
client.AddPortMapping(ctx, objects.Gateway, "udp", client.AddPortMapping(ctx, objects.Gateway, "udp",
internalPort, externalPort, lifetime) internalPort, externalPort, lifetime)
if err != nil { if err != nil {
return 0, fmt.Errorf("adding UDP port mapping: %w", err) return nil, fmt.Errorf("adding UDP port mapping: %w", err)
} }
checkLifetime(logger, "UDP", lifetime, assignedLifetime) checkLifetime(logger, "UDP", lifetime, assignedLifetime)
@@ -42,16 +42,15 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
client.AddPortMapping(ctx, objects.Gateway, "tcp", client.AddPortMapping(ctx, objects.Gateway, "tcp",
internalPort, externalPort, lifetime) internalPort, externalPort, lifetime)
if err != nil { if err != nil {
return 0, fmt.Errorf("adding TCP port mapping: %w", err) return nil, fmt.Errorf("adding TCP port mapping: %w", err)
} }
checkLifetime(logger, "TCP", lifetime, assignedLifetime) checkLifetime(logger, "TCP", lifetime, assignedLifetime)
checkExternalPorts(logger, assignedUDPExternalPort, assignedTCPExternalPort) checkExternalPorts(logger, assignedUDPExternalPort, assignedTCPExternalPort)
port = assignedTCPExternalPort
p.portForwarded = port p.portForwarded = assignedTCPExternalPort
return port, nil return []uint16{assignedTCPExternalPort}, nil
} }
func checkLifetime(logger utils.Logger, protocol string, func checkLifetime(logger utils.Logger, protocol string,

View File

@@ -22,7 +22,7 @@ type DNSLoop interface {
} }
type PortForwardedGetter interface { type PortForwardedGetter interface {
GetPortForwarded() (portForwarded uint16) GetPortsForwarded() (ports []uint16)
} }
type PublicIPLoop interface { type PublicIPLoop interface {

View File

@@ -123,12 +123,21 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
} }
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) { func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
port := h.pf.GetPortForwarded() ports := h.pf.GetPortsForwarded()
encoder := json.NewEncoder(w) encoder := json.NewEncoder(w)
data := portWrapper{Port: port} var data any
if err := encoder.Encode(data); err != nil { switch len(ports) {
case 0:
data = portWrapper{Port: 0} // TODO v4 change to portsWrapper
case 1:
data = portWrapper{Port: ports[0]} // TODO v4 change to portsWrapper
default:
data = portsWrapper{Ports: ports}
}
err := encoder.Encode(data)
if err != nil {
h.warner.Warn(err.Error()) h.warner.Warn(err.Error())
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return
} }
} }

View File

@@ -25,10 +25,14 @@ func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) {
} }
} }
type portWrapper struct { type portWrapper struct { // TODO v4 remove
Port uint16 `json:"port"` Port uint16 `json:"port"`
} }
type portsWrapper struct {
Ports []uint16 `json:"ports"`
}
type outcomeWrapper struct { type outcomeWrapper struct {
Outcome string `json:"outcome"` Outcome string `json:"outcome"`
} }

View File

@@ -47,7 +47,7 @@ type Provider interface {
type PortForwarder interface { type PortForwarder interface {
Name() string Name() string
PortForward(ctx context.Context, objects utils.PortForwardObjects) ( PortForward(ctx context.Context, objects utils.PortForwardObjects) (
port uint16, err error) ports []uint16, err error)
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
} }

View File

@@ -61,8 +61,8 @@ func (n *noPortForwarder) Name() string {
} }
func (n *noPortForwarder) PortForward(context.Context, pfutils.PortForwardObjects) ( func (n *noPortForwarder) PortForward(context.Context, pfutils.PortForwardObjects) (
port uint16, err error) { ports []uint16, err error) {
return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) return nil, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
} }
func (n *noPortForwarder) KeepPortForward(context.Context, pfutils.PortForwardObjects) (err error) { func (n *noPortForwarder) KeepPortForward(context.Context, pfutils.PortForwardObjects) (err error) {