From 8c730a6e4aa1f242fb15665222cf1b9ac0950cc4 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sun, 28 Jul 2024 19:49:45 +0000 Subject: [PATCH] chore(port-forward): support multiple port forwarded --- internal/portforward/interfaces.go | 2 +- internal/portforward/loop.go | 6 +-- internal/portforward/service/fs.go | 11 ++++- internal/portforward/service/helpers.go | 22 ++++++++++ internal/portforward/service/helpers_test.go | 43 +++++++++++++++++++ internal/portforward/service/interfaces.go | 2 +- internal/portforward/service/service.go | 8 ++-- internal/portforward/service/start.go | 26 +++++------ internal/portforward/service/stop.go | 26 +++++------ .../privateinternetaccess/portforward.go | 14 +++--- internal/provider/protonvpn/portforward.go | 13 +++--- internal/server/interfaces.go | 2 +- internal/server/openvpn.go | 17 ++++++-- internal/server/wrappers.go | 6 ++- internal/vpn/interfaces.go | 2 +- internal/vpn/portforward.go | 4 +- 16 files changed, 147 insertions(+), 57 deletions(-) create mode 100644 internal/portforward/service/helpers.go create mode 100644 internal/portforward/service/helpers_test.go diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index 639ec0bf..6b9eee0e 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -8,7 +8,7 @@ import ( type Service interface { Start(ctx context.Context) (runError <-chan error, err error) Stop() (err error) - GetPortForwarded() (port uint16) + GetPortsForwarded() (ports []uint16) } type Routing interface { diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index ba3acde8..c8d02a3b 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -150,11 +150,11 @@ func (l *Loop) Stop() (err error) { return nil } -func (l *Loop) GetPortForwarded() (port uint16) { +func (l *Loop) GetPortsForwarded() (ports []uint16) { if l.service == nil { - return 0 + return nil } - return l.service.GetPortForwarded() + return l.service.GetPortsForwarded() } func ptrTo[T any](value T) *T { diff --git a/internal/portforward/service/fs.go b/internal/portforward/service/fs.go index ed24fdf1..11d7d67c 100644 --- a/internal/portforward/service/fs.go +++ b/internal/portforward/service/fs.go @@ -3,13 +3,20 @@ package service import ( "fmt" "os" + "strings" ) -func (s *Service) writePortForwardedFile(port uint16) (err error) { +func (s *Service) writePortForwardedFile(ports []uint16) (err error) { + portStrings := make([]string, len(ports)) + for i, port := range ports { + portStrings[i] = fmt.Sprint(int(port)) + } + fileData := []byte(strings.Join(portStrings, "\n")) + filepath := s.settings.Filepath s.logger.Info("writing port file " + filepath) const perms = os.FileMode(0644) - err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms) + err = os.WriteFile(filepath, fileData, perms) if err != nil { return fmt.Errorf("writing file: %w", err) } diff --git a/internal/portforward/service/helpers.go b/internal/portforward/service/helpers.go new file mode 100644 index 00000000..db2eb9f4 --- /dev/null +++ b/internal/portforward/service/helpers.go @@ -0,0 +1,22 @@ +package service + +import ( + "fmt" + "strings" +) + +func portsToString(ports []uint16) (s string) { + switch len(ports) { + case 0: + return "no port forwarded" + case 1: + return "port forwarded is " + fmt.Sprint(int(ports[0])) + default: + portStrings := make([]string, len(ports)) + for i, port := range ports { + portStrings[i] = fmt.Sprint(int(port)) + } + return "ports forwarded are " + strings.Join(portStrings[:len(portStrings)-1], ", ") + + " and " + portStrings[len(portStrings)-1] + } +} diff --git a/internal/portforward/service/helpers_test.go b/internal/portforward/service/helpers_test.go new file mode 100644 index 00000000..e92f3c11 --- /dev/null +++ b/internal/portforward/service/helpers_test.go @@ -0,0 +1,43 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_portsToString(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ports []uint16 + s string + }{ + "no_port": { + s: "no port forwarded", + }, + "one_port": { + ports: []uint16{123}, + s: "port forwarded is 123", + }, + "two_ports": { + ports: []uint16{123, 456}, + s: "ports forwarded are 123 and 456", + }, + "three_ports": { + ports: []uint16{123, 456, 789}, + s: "ports forwarded are 123, 456 and 789", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := portsToString(testCase.ports) + + assert.Equal(t, testCase.s, s) + }) + } +} diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go index 4214ba44..7cf7d4c1 100644 --- a/internal/portforward/service/interfaces.go +++ b/internal/portforward/service/interfaces.go @@ -28,6 +28,6 @@ type Logger interface { type PortForwarder interface { Name() string PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - port uint16, err error) + ports []uint16, err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) } diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go index b6d55f83..081b54d2 100644 --- a/internal/portforward/service/service.go +++ b/internal/portforward/service/service.go @@ -9,7 +9,7 @@ import ( type Service struct { // State portMutex sync.RWMutex - port uint16 + ports []uint16 // Fixed parameters settings Settings puid int @@ -40,8 +40,10 @@ func New(settings Settings, routing Routing, client *http.Client, } } -func (s *Service) GetPortForwarded() (port uint16) { +func (s *Service) GetPortsForwarded() (ports []uint16) { s.portMutex.RLock() defer s.portMutex.RUnlock() - return s.port + ports = make([]uint16, len(s.ports)) + copy(ports, s.ports) + return ports } diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index 5c21b3cf..4f308375 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -31,33 +31,35 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) Username: s.settings.Username, Password: s.settings.Password, } - port, err := s.settings.PortForwarder.PortForward(ctx, obj) + ports, err := s.settings.PortForwarder.PortForward(ctx, obj) if err != nil { return nil, fmt.Errorf("port forwarding for the first time: %w", err) } - s.logger.Info("port forwarded is " + fmt.Sprint(int(port))) + s.logger.Info(portsToString(ports)) - err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) - if err != nil { - return nil, fmt.Errorf("allowing port in firewall: %w", err) - } - - if s.settings.ListeningPort != 0 { - err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort) + for _, port := range ports { + err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) if err != nil { - return nil, fmt.Errorf("redirecting port in firewall: %w", err) + return nil, fmt.Errorf("allowing port in firewall: %w", err) + } + + if s.settings.ListeningPort != 0 { + err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort) + if err != nil { + return nil, fmt.Errorf("redirecting port in firewall: %w", err) + } } } - err = s.writePortForwardedFile(port) + err = s.writePortForwardedFile(ports) if err != nil { _ = s.cleanup() return nil, fmt.Errorf("writing port file: %w", err) } s.portMutex.Lock() - s.port = port + s.ports = ports s.portMutex.Unlock() keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) diff --git a/internal/portforward/service/stop.go b/internal/portforward/service/stop.go index beef5072..e8b8d681 100644 --- a/internal/portforward/service/stop.go +++ b/internal/portforward/service/stop.go @@ -11,7 +11,7 @@ func (s *Service) Stop() (err error) { defer s.startStopMutex.Unlock() s.portMutex.RLock() - serviceNotRunning := s.port == 0 + serviceNotRunning := len(s.ports) == 0 s.portMutex.RUnlock() if serviceNotRunning { // TODO replace with goservices.ErrAlreadyStopped @@ -30,21 +30,23 @@ func (s *Service) cleanup() (err error) { s.portMutex.Lock() defer s.portMutex.Unlock() - err = s.portAllower.RemoveAllowedPort(context.Background(), s.port) - if err != nil { - return fmt.Errorf("blocking previous port in firewall: %w", err) - } - - if s.settings.ListeningPort != 0 { - ctx := context.Background() - const listeningPort = 0 // 0 to clear the redirection - err = s.portAllower.RedirectPort(ctx, s.settings.Interface, s.port, listeningPort) + for _, port := range s.ports { + err = s.portAllower.RemoveAllowedPort(context.Background(), port) if err != nil { - return fmt.Errorf("removing previous port redirection in firewall: %w", err) + return fmt.Errorf("blocking previous port in firewall: %w", err) + } + + if s.settings.ListeningPort != 0 { + ctx := context.Background() + const listeningPort = 0 // 0 to clear the redirection + err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, listeningPort) + if err != nil { + return fmt.Errorf("removing previous port redirection in firewall: %w", err) + } } } - s.port = 0 + s.ports = nil filepath := s.settings.Filepath s.logger.Info("removing port file " + filepath) diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 3955ced9..14b03e58 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -26,7 +26,7 @@ var ( // PortForward obtains a VPN server side port forwarded from PIA. func (p *Provider) PortForward(ctx context.Context, - objects utils.PortForwardObjects) (port uint16, err error) { + objects utils.PortForwardObjects) (ports []uint16, err error) { switch { case objects.ServerName == "": panic("server name cannot be empty") @@ -43,17 +43,17 @@ func (p *Provider) PortForward(ctx context.Context, logger := objects.Logger if !objects.CanPortForward { - return 0, fmt.Errorf("%w: for server %s", ErrServerNameNotFound, serverName) + return nil, fmt.Errorf("%w: for server %s", ErrServerNameNotFound, serverName) } privateIPClient, err := newHTTPClient(serverName) if err != nil { - return 0, fmt.Errorf("creating custom HTTP client: %w", err) + return nil, fmt.Errorf("creating custom HTTP client: %w", err) } data, err := readPIAPortForwardData(p.portForwardPath) if err != nil { - return 0, fmt.Errorf("reading saved port forwarded data: %w", err) + return nil, fmt.Errorf("reading saved port forwarded data: %w", err) } dataFound := data.Port > 0 @@ -73,7 +73,7 @@ func (p *Provider) PortForward(ctx context.Context, data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway, p.portForwardPath, objects.Username, objects.Password) if err != nil { - return 0, fmt.Errorf("refreshing port forward data: %w", err) + return nil, fmt.Errorf("refreshing port forward data: %w", err) } durationToExpiration = data.Expiration.Sub(p.timeNow()) } @@ -81,10 +81,10 @@ func (p *Provider) PortForward(ctx context.Context, // First time binding if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil { - return 0, fmt.Errorf("binding port: %w", err) + return nil, fmt.Errorf("binding port: %w", err) } - return data.Port, nil + return []uint16{data.Port}, nil } var ( diff --git a/internal/provider/protonvpn/portforward.go b/internal/provider/protonvpn/portforward.go index d808c428..54d3fca7 100644 --- a/internal/provider/protonvpn/portforward.go +++ b/internal/provider/protonvpn/portforward.go @@ -13,7 +13,7 @@ import ( // PortForward obtains a VPN server side port forwarded from ProtonVPN gateway. func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - port uint16, err error) { + ports []uint16, err error) { client := natpmp.New() _, externalIPv4Address, err := client.ExternalAddress(ctx, objects.Gateway) @@ -21,7 +21,7 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj if strings.HasSuffix(err.Error(), "connection refused") { err = fmt.Errorf("%w - make sure you have +pmp at the end of your OpenVPN username", err) } - return 0, fmt.Errorf("getting external IPv4 address: %w", err) + return nil, fmt.Errorf("getting external IPv4 address: %w", err) } logger := objects.Logger @@ -34,7 +34,7 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj client.AddPortMapping(ctx, objects.Gateway, "udp", internalPort, externalPort, lifetime) if err != nil { - return 0, fmt.Errorf("adding UDP port mapping: %w", err) + return nil, fmt.Errorf("adding UDP port mapping: %w", err) } checkLifetime(logger, "UDP", lifetime, assignedLifetime) @@ -42,16 +42,15 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj client.AddPortMapping(ctx, objects.Gateway, "tcp", internalPort, externalPort, lifetime) if err != nil { - return 0, fmt.Errorf("adding TCP port mapping: %w", err) + return nil, fmt.Errorf("adding TCP port mapping: %w", err) } checkLifetime(logger, "TCP", lifetime, assignedLifetime) checkExternalPorts(logger, assignedUDPExternalPort, assignedTCPExternalPort) - port = assignedTCPExternalPort - p.portForwarded = port + p.portForwarded = assignedTCPExternalPort - return port, nil + return []uint16{assignedTCPExternalPort}, nil } func checkLifetime(logger utils.Logger, protocol string, diff --git a/internal/server/interfaces.go b/internal/server/interfaces.go index 62c24acb..5b470251 100644 --- a/internal/server/interfaces.go +++ b/internal/server/interfaces.go @@ -22,7 +22,7 @@ type DNSLoop interface { } type PortForwardedGetter interface { - GetPortForwarded() (portForwarded uint16) + GetPortsForwarded() (ports []uint16) } type PublicIPLoop interface { diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index 6f8a2d08..d34dfd0f 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -123,12 +123,21 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) { } func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) { - port := h.pf.GetPortForwarded() + ports := h.pf.GetPortsForwarded() encoder := json.NewEncoder(w) - data := portWrapper{Port: port} - if err := encoder.Encode(data); err != nil { + var data any + switch len(ports) { + case 0: + data = portWrapper{Port: 0} // TODO v4 change to portsWrapper + case 1: + data = portWrapper{Port: ports[0]} // TODO v4 change to portsWrapper + default: + data = portsWrapper{Ports: ports} + } + + err := encoder.Encode(data) + if err != nil { h.warner.Warn(err.Error()) w.WriteHeader(http.StatusInternalServerError) - return } } diff --git a/internal/server/wrappers.go b/internal/server/wrappers.go index 26081c9a..c35f1c6a 100644 --- a/internal/server/wrappers.go +++ b/internal/server/wrappers.go @@ -25,10 +25,14 @@ func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) { } } -type portWrapper struct { +type portWrapper struct { // TODO v4 remove Port uint16 `json:"port"` } +type portsWrapper struct { + Ports []uint16 `json:"ports"` +} + type outcomeWrapper struct { Outcome string `json:"outcome"` } diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 6a4a145b..d116b8dc 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -47,7 +47,7 @@ type Provider interface { type PortForwarder interface { Name() string PortForward(ctx context.Context, objects utils.PortForwardObjects) ( - port uint16, err error) + ports []uint16, err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) } diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index c773c3fc..b31600c8 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -61,8 +61,8 @@ func (n *noPortForwarder) Name() string { } func (n *noPortForwarder) PortForward(context.Context, pfutils.PortForwardObjects) ( - port uint16, err error) { - return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) + ports []uint16, err error) { + return nil, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) } func (n *noPortForwarder) KeepPortForward(context.Context, pfutils.PortForwardObjects) (err error) {