Maint: port forwarding refactoring (#543)

- portforward package
- portforward run loop
- Less functional arguments and cycles
This commit is contained in:
Quentin McGaw
2021-07-28 08:35:44 -07:00
committed by GitHub
parent c777f8d97d
commit 2998cf5e48
25 changed files with 639 additions and 255 deletions

View File

@@ -42,6 +42,7 @@ func (l *Loop) collectLines(stdout, stderr <-chan string, done chan<- struct{})
}
if strings.Contains(line, "Initialization Sequence Completed") {
l.tunnelReady <- struct{}{}
l.startPFCh <- struct{}{}
}
}
}

View File

@@ -1,7 +1,6 @@
package openvpn
import (
"net"
"net/http"
"time"
@@ -11,6 +10,8 @@ import (
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn/state"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/logging"
)
@@ -22,8 +23,6 @@ type Looper interface {
loopstate.Applier
SettingsGetSetter
ServersGetterSetter
PortForwadedGetter
PortForwader
}
type Loop struct {
@@ -35,19 +34,21 @@ type Loop struct {
pgid int
targetConfPath string
// Configurators
conf StarterAuthWriter
fw firewallConfigurer
conf StarterAuthWriter
fw firewallConfigurer
routing routing.VPNLocalGatewayIPGetter
portForward portforward.StartStopper
// Other objects
logger, pfLogger logging.Logger
client *http.Client
tunnelReady chan<- struct{}
logger logging.Logger
client *http.Client
tunnelReady chan<- struct{}
// Internal channels and values
stop <-chan struct{}
stopped chan<- struct{}
start <-chan struct{}
running chan<- models.LoopStatus
portForwardSignals chan net.IP
userTrigger bool
stop <-chan struct{}
stopped chan<- struct{}
start <-chan struct{}
running chan<- models.LoopStatus
userTrigger bool
startPFCh chan struct{}
// Internal constant values
backoffTime time.Duration
}
@@ -63,7 +64,8 @@ const (
func NewLoop(settings configuration.OpenVPN, username string,
puid, pgid int, allServers models.AllServers, conf Configurator,
fw firewallConfigurer, logger logging.ParentLogger,
fw firewallConfigurer, routing routing.VPNLocalGatewayIPGetter,
portForward portforward.StartStopper, logger logging.Logger,
client *http.Client, tunnelReady chan<- struct{}) *Loop {
start := make(chan struct{})
running := make(chan models.LoopStatus)
@@ -74,24 +76,25 @@ func NewLoop(settings configuration.OpenVPN, username string,
state := state.New(statusManager, settings, allServers)
return &Loop{
statusManager: statusManager,
state: state,
username: username,
puid: puid,
pgid: pgid,
targetConfPath: constants.OpenVPNConf,
conf: conf,
fw: fw,
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}),
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}),
client: client,
tunnelReady: tunnelReady,
start: start,
running: running,
stop: stop,
stopped: stopped,
portForwardSignals: make(chan net.IP),
userTrigger: true,
backoffTime: defaultBackoffTime,
statusManager: statusManager,
state: state,
username: username,
puid: puid,
pgid: pgid,
targetConfPath: constants.OpenVPNConf,
conf: conf,
fw: fw,
routing: routing,
portForward: portForward,
logger: logger,
client: client,
tunnelReady: tunnelReady,
start: start,
running: running,
stop: stop,
stopped: stopped,
userTrigger: true,
startPFCh: make(chan struct{}),
backoffTime: defaultBackoffTime,
}
}

View File

@@ -0,0 +1,47 @@
package openvpn
import (
"context"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider"
)
func (l *Loop) startPortForwarding(ctx context.Context,
portForwarder provider.PortForwarder, serverName string) {
if !l.GetSettings().Provider.PortForwarding.Enabled {
return
}
// only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP()
if err != nil {
l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error())
return
}
l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{
PortForwarder: portForwarder,
Gateway: gateway,
ServerName: serverName,
Interface: constants.TUN,
}
_, err = l.portForward.Start(ctx, pfData)
if err != nil {
l.logger.Error("cannot start port forwarding: " + err.Error())
}
}
func (l *Loop) stopPortForwarding(ctx context.Context, timeout time.Duration) {
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
_, err := l.portForward.Stop(ctx)
if err != nil {
l.logger.Error("cannot stop port forwarding: " + err.Error())
}
}

View File

@@ -1,39 +0,0 @@
package openvpn
import (
"context"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/openvpn/state"
"github.com/qdm12/gluetun/internal/provider"
)
type PortForwadedGetter = state.PortForwardedGetter
func (l *Loop) GetPortForwarded() (port uint16) {
return l.state.GetPortForwarded()
}
type PortForwader interface {
PortForward(vpnGatewayIP net.IP)
}
func (l *Loop) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
// portForward is a blocking operation which may or may not be infinite.
// You should therefore always call it in a goroutine.
func (l *Loop) portForward(ctx context.Context,
providerConf provider.Provider, client *http.Client, gateway net.IP) {
settings := l.state.GetSettings()
if !settings.Provider.PortForwarding.Enabled {
return
}
syncState := func(port uint16) (pfFilepath string) {
l.state.SetPortForwarded(port)
settings := l.state.GetSettings()
return settings.Provider.PortForwarding.Filepath
}
providerConf.PortForward(ctx, client, l.pfLogger,
gateway, l.fw, syncState)
}

View File

@@ -88,41 +88,31 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
<-lineCollectionDone
}
// Needs the stream line from main.go to know when the tunnel is up
portForwardDone := make(chan struct{})
go func(ctx context.Context) {
defer close(portForwardDone)
select {
// TODO have a way to disable pf with a context
case <-ctx.Done():
return
case gateway := <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client, gateway)
}
}(openvpnCtx)
l.backoffTime = defaultBackoffTime
l.signalOrSetStatus(constants.Running)
stayHere := true
for stayHere {
select {
case <-l.startPFCh:
l.startPortForwarding(ctx, providerConf, connection.Hostname)
case <-ctx.Done():
const pfTimeout = 100 * time.Millisecond
l.stopPortForwarding(context.Background(), pfTimeout)
openvpnCancel()
<-waitError
close(waitError)
closeStreams()
<-portForwardDone
return
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
l.stopPortForwarding(ctx, 0)
openvpnCancel()
<-waitError
// do not close waitError or the waitError
// select case will trigger
closeStreams()
<-portForwardDone
l.stopped <- struct{}{}
case <-l.start:
l.userTrigger = true
@@ -134,9 +124,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.statusManager.Lock() // prevent SetStatus from running in parallel
l.stopPortForwarding(ctx, 0)
openvpnCancel()
l.statusManager.SetStatus(constants.Crashed)
<-portForwardDone
l.logAndWait(ctx, err)
stayHere = false

View File

@@ -1,26 +0,0 @@
package state
type PortForwardedGetterSetter interface {
PortForwardedGetter
SetPortForwarded(port uint16)
}
type PortForwardedGetter interface {
GetPortForwarded() (port uint16)
}
// GetPortForwarded is used by the control HTTP server
// to obtain the port currently forwarded.
func (s *State) GetPortForwarded() (port uint16) {
s.portForwardedMu.RLock()
defer s.portForwardedMu.RUnlock()
return s.portForwarded
}
// SetPortForwarded is only used from within the OpenVPN loop
// to set the port forwarded.
func (s *State) SetPortForwarded(port uint16) {
s.portForwardedMu.Lock()
defer s.portForwardedMu.Unlock()
s.portForwarded = port
}

View File

@@ -13,7 +13,6 @@ var _ Manager = (*State)(nil)
type Manager interface {
SettingsGetSetter
ServersGetterSetter
PortForwardedGetterSetter
GetSettingsAndServers() (settings configuration.OpenVPN,
allServers models.AllServers)
}
@@ -36,9 +35,6 @@ type State struct {
allServers models.AllServers
allServersMu sync.RWMutex
portForwarded uint16
portForwardedMu sync.RWMutex
}
func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN,