chore(port-forward): support multiple port forwarded
This commit is contained in:
@@ -8,7 +8,7 @@ import (
|
||||
type Service interface {
|
||||
Start(ctx context.Context) (runError <-chan error, err error)
|
||||
Stop() (err error)
|
||||
GetPortForwarded() (port uint16)
|
||||
GetPortsForwarded() (ports []uint16)
|
||||
}
|
||||
|
||||
type Routing interface {
|
||||
|
||||
@@ -150,11 +150,11 @@ func (l *Loop) Stop() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
func (l *Loop) GetPortsForwarded() (ports []uint16) {
|
||||
if l.service == nil {
|
||||
return 0
|
||||
return nil
|
||||
}
|
||||
return l.service.GetPortForwarded()
|
||||
return l.service.GetPortsForwarded()
|
||||
}
|
||||
|
||||
func ptrTo[T any](value T) *T {
|
||||
|
||||
@@ -3,13 +3,20 @@ package service
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Service) writePortForwardedFile(port uint16) (err error) {
|
||||
func (s *Service) writePortForwardedFile(ports []uint16) (err error) {
|
||||
portStrings := make([]string, len(ports))
|
||||
for i, port := range ports {
|
||||
portStrings[i] = fmt.Sprint(int(port))
|
||||
}
|
||||
fileData := []byte(strings.Join(portStrings, "\n"))
|
||||
|
||||
filepath := s.settings.Filepath
|
||||
s.logger.Info("writing port file " + filepath)
|
||||
const perms = os.FileMode(0644)
|
||||
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
|
||||
err = os.WriteFile(filepath, fileData, perms)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing file: %w", err)
|
||||
}
|
||||
|
||||
22
internal/portforward/service/helpers.go
Normal file
22
internal/portforward/service/helpers.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func portsToString(ports []uint16) (s string) {
|
||||
switch len(ports) {
|
||||
case 0:
|
||||
return "no port forwarded"
|
||||
case 1:
|
||||
return "port forwarded is " + fmt.Sprint(int(ports[0]))
|
||||
default:
|
||||
portStrings := make([]string, len(ports))
|
||||
for i, port := range ports {
|
||||
portStrings[i] = fmt.Sprint(int(port))
|
||||
}
|
||||
return "ports forwarded are " + strings.Join(portStrings[:len(portStrings)-1], ", ") +
|
||||
" and " + portStrings[len(portStrings)-1]
|
||||
}
|
||||
}
|
||||
43
internal/portforward/service/helpers_test.go
Normal file
43
internal/portforward/service/helpers_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_portsToString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
ports []uint16
|
||||
s string
|
||||
}{
|
||||
"no_port": {
|
||||
s: "no port forwarded",
|
||||
},
|
||||
"one_port": {
|
||||
ports: []uint16{123},
|
||||
s: "port forwarded is 123",
|
||||
},
|
||||
"two_ports": {
|
||||
ports: []uint16{123, 456},
|
||||
s: "ports forwarded are 123 and 456",
|
||||
},
|
||||
"three_ports": {
|
||||
ports: []uint16{123, 456, 789},
|
||||
s: "ports forwarded are 123, 456 and 789",
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := portsToString(testCase.ports)
|
||||
|
||||
assert.Equal(t, testCase.s, s)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,6 @@ type Logger interface {
|
||||
type PortForwarder interface {
|
||||
Name() string
|
||||
PortForward(ctx context.Context, objects utils.PortForwardObjects) (
|
||||
port uint16, err error)
|
||||
ports []uint16, err error)
|
||||
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
type Service struct {
|
||||
// State
|
||||
portMutex sync.RWMutex
|
||||
port uint16
|
||||
ports []uint16
|
||||
// Fixed parameters
|
||||
settings Settings
|
||||
puid int
|
||||
@@ -40,8 +40,10 @@ func New(settings Settings, routing Routing, client *http.Client,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) GetPortForwarded() (port uint16) {
|
||||
func (s *Service) GetPortsForwarded() (ports []uint16) {
|
||||
s.portMutex.RLock()
|
||||
defer s.portMutex.RUnlock()
|
||||
return s.port
|
||||
ports = make([]uint16, len(s.ports))
|
||||
copy(ports, s.ports)
|
||||
return ports
|
||||
}
|
||||
|
||||
@@ -31,13 +31,14 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
Username: s.settings.Username,
|
||||
Password: s.settings.Password,
|
||||
}
|
||||
port, err := s.settings.PortForwarder.PortForward(ctx, obj)
|
||||
ports, err := s.settings.PortForwarder.PortForward(ctx, obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("port forwarded is " + fmt.Sprint(int(port)))
|
||||
s.logger.Info(portsToString(ports))
|
||||
|
||||
for _, port := range ports {
|
||||
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("allowing port in firewall: %w", err)
|
||||
@@ -49,15 +50,16 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
||||
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = s.writePortForwardedFile(port)
|
||||
err = s.writePortForwardedFile(ports)
|
||||
if err != nil {
|
||||
_ = s.cleanup()
|
||||
return nil, fmt.Errorf("writing port file: %w", err)
|
||||
}
|
||||
|
||||
s.portMutex.Lock()
|
||||
s.port = port
|
||||
s.ports = ports
|
||||
s.portMutex.Unlock()
|
||||
|
||||
keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
|
||||
|
||||
@@ -11,7 +11,7 @@ func (s *Service) Stop() (err error) {
|
||||
defer s.startStopMutex.Unlock()
|
||||
|
||||
s.portMutex.RLock()
|
||||
serviceNotRunning := s.port == 0
|
||||
serviceNotRunning := len(s.ports) == 0
|
||||
s.portMutex.RUnlock()
|
||||
if serviceNotRunning {
|
||||
// TODO replace with goservices.ErrAlreadyStopped
|
||||
@@ -30,7 +30,8 @@ func (s *Service) cleanup() (err error) {
|
||||
s.portMutex.Lock()
|
||||
defer s.portMutex.Unlock()
|
||||
|
||||
err = s.portAllower.RemoveAllowedPort(context.Background(), s.port)
|
||||
for _, port := range s.ports {
|
||||
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocking previous port in firewall: %w", err)
|
||||
}
|
||||
@@ -38,13 +39,14 @@ func (s *Service) cleanup() (err error) {
|
||||
if s.settings.ListeningPort != 0 {
|
||||
ctx := context.Background()
|
||||
const listeningPort = 0 // 0 to clear the redirection
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, s.port, listeningPort)
|
||||
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, listeningPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing previous port redirection in firewall: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.port = 0
|
||||
s.ports = nil
|
||||
|
||||
filepath := s.settings.Filepath
|
||||
s.logger.Info("removing port file " + filepath)
|
||||
|
||||
@@ -26,7 +26,7 @@ var (
|
||||
|
||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||
func (p *Provider) PortForward(ctx context.Context,
|
||||
objects utils.PortForwardObjects) (port uint16, err error) {
|
||||
objects utils.PortForwardObjects) (ports []uint16, err error) {
|
||||
switch {
|
||||
case objects.ServerName == "":
|
||||
panic("server name cannot be empty")
|
||||
@@ -43,17 +43,17 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
logger := objects.Logger
|
||||
|
||||
if !objects.CanPortForward {
|
||||
return 0, fmt.Errorf("%w: for server %s", ErrServerNameNotFound, serverName)
|
||||
return nil, fmt.Errorf("%w: for server %s", ErrServerNameNotFound, serverName)
|
||||
}
|
||||
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("creating custom HTTP client: %w", err)
|
||||
return nil, fmt.Errorf("creating custom HTTP client: %w", err)
|
||||
}
|
||||
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("reading saved port forwarded data: %w", err)
|
||||
return nil, fmt.Errorf("reading saved port forwarded data: %w", err)
|
||||
}
|
||||
|
||||
dataFound := data.Port > 0
|
||||
@@ -73,7 +73,7 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
|
||||
p.portForwardPath, objects.Username, objects.Password)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("refreshing port forward data: %w", err)
|
||||
return nil, fmt.Errorf("refreshing port forward data: %w", err)
|
||||
}
|
||||
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
||||
}
|
||||
@@ -81,10 +81,10 @@ func (p *Provider) PortForward(ctx context.Context,
|
||||
|
||||
// First time binding
|
||||
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
|
||||
return 0, fmt.Errorf("binding port: %w", err)
|
||||
return nil, fmt.Errorf("binding port: %w", err)
|
||||
}
|
||||
|
||||
return data.Port, nil
|
||||
return []uint16{data.Port}, nil
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
// PortForward obtains a VPN server side port forwarded from ProtonVPN gateway.
|
||||
func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObjects) (
|
||||
port uint16, err error) {
|
||||
ports []uint16, err error) {
|
||||
client := natpmp.New()
|
||||
_, externalIPv4Address, err := client.ExternalAddress(ctx,
|
||||
objects.Gateway)
|
||||
@@ -21,7 +21,7 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
|
||||
if strings.HasSuffix(err.Error(), "connection refused") {
|
||||
err = fmt.Errorf("%w - make sure you have +pmp at the end of your OpenVPN username", err)
|
||||
}
|
||||
return 0, fmt.Errorf("getting external IPv4 address: %w", err)
|
||||
return nil, fmt.Errorf("getting external IPv4 address: %w", err)
|
||||
}
|
||||
|
||||
logger := objects.Logger
|
||||
@@ -34,7 +34,7 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
|
||||
client.AddPortMapping(ctx, objects.Gateway, "udp",
|
||||
internalPort, externalPort, lifetime)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("adding UDP port mapping: %w", err)
|
||||
return nil, fmt.Errorf("adding UDP port mapping: %w", err)
|
||||
}
|
||||
checkLifetime(logger, "UDP", lifetime, assignedLifetime)
|
||||
|
||||
@@ -42,16 +42,15 @@ func (p *Provider) PortForward(ctx context.Context, objects utils.PortForwardObj
|
||||
client.AddPortMapping(ctx, objects.Gateway, "tcp",
|
||||
internalPort, externalPort, lifetime)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("adding TCP port mapping: %w", err)
|
||||
return nil, fmt.Errorf("adding TCP port mapping: %w", err)
|
||||
}
|
||||
checkLifetime(logger, "TCP", lifetime, assignedLifetime)
|
||||
|
||||
checkExternalPorts(logger, assignedUDPExternalPort, assignedTCPExternalPort)
|
||||
port = assignedTCPExternalPort
|
||||
|
||||
p.portForwarded = port
|
||||
p.portForwarded = assignedTCPExternalPort
|
||||
|
||||
return port, nil
|
||||
return []uint16{assignedTCPExternalPort}, nil
|
||||
}
|
||||
|
||||
func checkLifetime(logger utils.Logger, protocol string,
|
||||
|
||||
@@ -22,7 +22,7 @@ type DNSLoop interface {
|
||||
}
|
||||
|
||||
type PortForwardedGetter interface {
|
||||
GetPortForwarded() (portForwarded uint16)
|
||||
GetPortsForwarded() (ports []uint16)
|
||||
}
|
||||
|
||||
type PublicIPLoop interface {
|
||||
|
||||
@@ -123,12 +123,21 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
|
||||
port := h.pf.GetPortForwarded()
|
||||
ports := h.pf.GetPortsForwarded()
|
||||
encoder := json.NewEncoder(w)
|
||||
data := portWrapper{Port: port}
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
var data any
|
||||
switch len(ports) {
|
||||
case 0:
|
||||
data = portWrapper{Port: 0} // TODO v4 change to portsWrapper
|
||||
case 1:
|
||||
data = portWrapper{Port: ports[0]} // TODO v4 change to portsWrapper
|
||||
default:
|
||||
data = portsWrapper{Ports: ports}
|
||||
}
|
||||
|
||||
err := encoder.Encode(data)
|
||||
if err != nil {
|
||||
h.warner.Warn(err.Error())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,10 +25,14 @@ func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
type portWrapper struct {
|
||||
type portWrapper struct { // TODO v4 remove
|
||||
Port uint16 `json:"port"`
|
||||
}
|
||||
|
||||
type portsWrapper struct {
|
||||
Ports []uint16 `json:"ports"`
|
||||
}
|
||||
|
||||
type outcomeWrapper struct {
|
||||
Outcome string `json:"outcome"`
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ type Provider interface {
|
||||
type PortForwarder interface {
|
||||
Name() string
|
||||
PortForward(ctx context.Context, objects utils.PortForwardObjects) (
|
||||
port uint16, err error)
|
||||
ports []uint16, err error)
|
||||
KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error)
|
||||
}
|
||||
|
||||
|
||||
@@ -61,8 +61,8 @@ func (n *noPortForwarder) Name() string {
|
||||
}
|
||||
|
||||
func (n *noPortForwarder) PortForward(context.Context, pfutils.PortForwardObjects) (
|
||||
port uint16, err error) {
|
||||
return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
||||
ports []uint16, err error) {
|
||||
return nil, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
||||
}
|
||||
|
||||
func (n *noPortForwarder) KeepPortForward(context.Context, pfutils.PortForwardObjects) (err error) {
|
||||
|
||||
Reference in New Issue
Block a user