Compare commits

..

3 Commits

Author SHA1 Message Date
Quentin McGaw (desktop)
f569998c93 Fix: install latest apk-tools before using apk 2021-08-09 14:44:06 +00:00
Quentin McGaw (desktop)
9877366c51 Fix: install latest apk-tools by default 2021-08-09 14:43:46 +00:00
Quentin McGaw (desktop)
61066e3896 Fix restart mutex unlocking for loops 2021-08-09 14:38:15 +00:00
49 changed files with 2316 additions and 2244 deletions

View File

@@ -15,10 +15,6 @@ assignees: qdm12
**Host OS** (approximate answer is fine too): Ubuntu 18
<!---
🚧 If this is about the Unraid template see https://github.com/qdm12/gluetun/discussions/550
-->
**CPU arch** or **device name**: amd64
**What VPN provider are you using**:

View File

@@ -147,10 +147,10 @@ ENV VPNSP=pia \
# Shadowsocks
SHADOWSOCKS=off \
SHADOWSOCKS_LOG=off \
SHADOWSOCKS_ADDRESS=":8388" \
SHADOWSOCKS_PORT=8388 \
SHADOWSOCKS_PASSWORD= \
SHADOWSOCKS_PASSWORD_SECRETFILE=/run/secrets/shadowsocks_password \
SHADOWSOCKS_CIPHER=chacha20-ietf-poly1305 \
SHADOWSOCKS_METHOD=chacha20-ietf-poly1305 \
UPDATER_PERIOD=0
ENTRYPOINT ["/entrypoint"]
EXPOSE 8000/tcp 8888/tcp 8388/tcp 8388/udp

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
@@ -23,7 +24,6 @@ import (
"github.com/qdm12/gluetun/internal/httpproxy"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/server"
@@ -321,16 +321,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupSettings)
otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings)
portForwardLogger := logger.NewChild(logging.Settings{Prefix: "port forwarding: "})
portForwardLooper := portforward.NewLoop(allSettings.OpenVPN.Provider.PortForwarding,
httpClient, firewallConf, portForwardLogger)
portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler(
"port forwarding", goshutdown.GoRoutineSettings{Timeout: time.Second})
go portForwardLooper.Run(portForwardCtx, portForwardDone)
openvpnLogger := logger.NewChild(logging.Settings{Prefix: "openvpn: "})
openvpnLooper := openvpn.NewLoop(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
ovpnConf, firewallConf, routingConf, portForwardLooper, openvpnLogger, httpClient, tunnelReadyCh)
ovpnConf, firewallConf, logger, httpClient, tunnelReadyCh)
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
// wait for restartOpenvpn
@@ -386,7 +378,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"events routing", defaultGoRoutineSettings)
go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh,
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
allSettings.VersionInformation)
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
)
controlGroupHandler.Add(eventsRoutingHandler)
controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port))
@@ -395,7 +388,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"http server", defaultGoRoutineSettings)
httpServer := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.NewChild(logging.Settings{Prefix: "http server: "}),
buildInfo, openvpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper)
buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
go httpServer.Run(httpServerCtx, httpServerDone)
controlGroupHandler.Add(httpServerHandler)
@@ -413,7 +406,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
}
orderHandler := goshutdown.NewOrder("gluetun", orderSettings)
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
openvpnHandler, portForwardHandler, otherGroupHandler)
openvpnHandler, otherGroupHandler)
// Start openvpn for the first time in a blocking call
// until openvpn is launched
@@ -421,6 +414,13 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
<-ctx.Done()
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file " + allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := os.Remove(allSettings.OpenVPN.Provider.PortForwarding.Filepath); err != nil {
logger.Error(err.Error())
}
}
return orderHandler.Shutdown(context.Background())
}
@@ -450,7 +450,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
tunnelReadyCh <-chan struct{},
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
routing routing.VPNGetter, logger logging.Logger, httpClient *http.Client,
versionInformation bool) {
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
defer close(done)
// for linters only
@@ -503,6 +503,15 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
updaterTickerDone = make(chan struct{})
go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone)
go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone)
if portForwardingEnabled {
// vpnGateway required only for PIA
vpnGateway, err := routing.VPNLocalGatewayIP()
if err != nil {
logger.Error("cannot get VPN local gateway IP: " + err.Error())
}
logger.Info("VPN gateway IP address: " + vpnGateway.String())
startPortForward(vpnGateway)
}
}
}
}

2
go.mod
View File

@@ -9,7 +9,7 @@ require (
github.com/qdm12/golibs v0.0.0-20210723175634-a75ca7fd74c2
github.com/qdm12/goshutdown v0.1.0
github.com/qdm12/gosplash v0.1.0
github.com/qdm12/ss-server v0.3.0
github.com/qdm12/ss-server v0.2.0
github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e
github.com/stretchr/testify v1.7.0
github.com/vishvananda/netlink v1.1.0

4
go.sum
View File

@@ -72,8 +72,8 @@ github.com/qdm12/goshutdown v0.1.0 h1:lmwnygdXtnr2pa6VqfR/bm8077/BnBef1+7CP96B7S
github.com/qdm12/goshutdown v0.1.0/go.mod h1:/LP3MWLqI+wGH/ijfaUG+RHzBbKXIiVKnrg5vXOCf6Q=
github.com/qdm12/gosplash v0.1.0 h1:Sfl+zIjFZFP7b0iqf2l5UkmEY97XBnaKkH3FNY6Gf7g=
github.com/qdm12/gosplash v0.1.0/go.mod h1:+A3fWW4/rUeDXhY3ieBzwghKdnIPFJgD8K3qQkenJlw=
github.com/qdm12/ss-server v0.3.0 h1:BfKv4OU6dYb2KcDMYpTc7LIuO2jB73g3JCzy988GrLI=
github.com/qdm12/ss-server v0.3.0/go.mod h1:ug+nWfuzKw/h5fxL1B6e9/OhkVuWJX4i2V1Pf0pJU1o=
github.com/qdm12/ss-server v0.2.0 h1:+togLzeeLAJ68MD1JqOWvYi9rl9t/fx1Qh7wKzZhY1g=
github.com/qdm12/ss-server v0.2.0/go.mod h1:+1bWO1EfWNvsGM5Cuep6vneChK2OHniqtAsED9Fh1y0=
github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e h1:4q+uFLawkaQRq3yARYLsjJPZd2wYwxn4g6G/5v0xW1g=
github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e/go.mod h1:UvJRGkZ9XL3/D7e7JiTTVLm1F3Cymd3/gFpD6frEpBo=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=

View File

@@ -4,10 +4,11 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/params"
)
func (settings *Provider) cyberghostLines() (lines []string) {
lines = append(lines, lastIndent+"Server groups: "+commaJoin(settings.ServerSelection.Groups))
lines = append(lines, lastIndent+"Server group: "+settings.ServerSelection.Group)
if len(settings.ServerSelection.Regions) > 0 {
lines = append(lines, lastIndent+"Regions: "+commaJoin(settings.ServerSelection.Regions))
@@ -51,8 +52,8 @@ func (settings *Provider) readCyberghost(r reader) (err error) {
return err
}
settings.ServerSelection.Groups, err = r.env.CSVInside("CYBERGHOST_GROUP",
constants.CyberghostGroupChoices())
settings.ServerSelection.Group, err = r.env.Inside("CYBERGHOST_GROUP",
constants.CyberghostGroupChoices(), params.Default("Premium UDP Europe"))
if err != nil {
return fmt.Errorf("environment variable CYBERGHOST_GROUP: %w", err)
}

View File

@@ -33,7 +33,7 @@ func Test_OpenVPN_JSON(t *testing.T) {
"server_selection": {
"tcp": false,
"regions": null,
"groups": null,
"group": "",
"countries": null,
"cities": null,
"hostnames": null,

View File

@@ -24,7 +24,7 @@ func Test_Provider_lines(t *testing.T) {
settings: Provider{
Name: constants.Cyberghost,
ServerSelection: ServerSelection{
Groups: []string{"group"},
Group: "group",
Regions: []string{"a", "El country"},
},
ExtraConfigOptions: ExtraConfigOptions{
@@ -35,7 +35,7 @@ func Test_Provider_lines(t *testing.T) {
lines: []string{
"|--Cyberghost settings:",
" |--Network protocol: udp",
" |--Server groups: group",
" |--Server group: group",
" |--Regions: a, El country",
" |--Client key is set",
" |--Client certificate is set",

View File

@@ -13,7 +13,7 @@ type ServerSelection struct { //nolint:maligned
Regions []string `json:"regions"`
// Cyberghost
Groups []string `json:"groups"`
Group string `json:"group"`
// Fastestvpn, HideMyAss, IPVanish, IVPN, Mullvad, PrivateVPN, Protonvpn, PureVPN, VPNUnlimited
Countries []string `json:"countries"`

View File

@@ -2,16 +2,19 @@ package configuration
import (
"fmt"
"strconv"
"strings"
"github.com/qdm12/golibs/params"
"github.com/qdm12/ss-server/pkg/tcpudp"
)
// ShadowSocks contains settings to configure the Shadowsocks server.
type ShadowSocks struct {
Enabled bool
tcpudp.Settings
Method string
Password string
Port uint16
Enabled bool
Log bool
}
func (settings *ShadowSocks) String() string {
@@ -25,12 +28,12 @@ func (settings *ShadowSocks) lines() (lines []string) {
lines = append(lines, lastIndent+"Shadowsocks server:")
lines = append(lines, indent+lastIndent+"Listening address: "+settings.Address)
lines = append(lines, indent+lastIndent+"Listening port: "+strconv.Itoa(int(settings.Port)))
lines = append(lines, indent+lastIndent+"Cipher: "+settings.CipherName)
lines = append(lines, indent+lastIndent+"Method: "+settings.Method)
if settings.LogAddresses {
lines = append(lines, indent+lastIndent+"Log addresses: enabled")
if settings.Log {
lines = append(lines, indent+lastIndent+"Logging: enabled")
}
return lines
@@ -49,61 +52,24 @@ func (settings *ShadowSocks) read(r reader) (err error) {
return err
}
settings.LogAddresses, err = r.env.OnOff("SHADOWSOCKS_LOG", params.Default("off"))
settings.Log, err = r.env.OnOff("SHADOWSOCKS_LOG", params.Default("off"))
if err != nil {
return fmt.Errorf("environment variable SHADOWSOCKS_LOG: %w", err)
}
settings.CipherName, err = r.env.Get("SHADOWSOCKS_CIPHER", params.Default("chacha20-ietf-poly1305"),
params.RetroKeys([]string{"SHADOWSOCKS_METHOD"}, r.onRetroActive))
settings.Method, err = r.env.Get("SHADOWSOCKS_METHOD", params.Default("chacha20-ietf-poly1305"))
if err != nil {
return fmt.Errorf("environment variable SHADOWSOCKS_CIPHER (or SHADOWSOCKS_METHOD): %w", err)
return fmt.Errorf("environment variable SHADOWSOCKS_METHOD: %w", err)
}
warning, err := settings.getAddress(r.env)
if warning != "" {
var warning string
settings.Port, warning, err = r.env.ListeningPort("SHADOWSOCKS_PORT", params.Default("8388"))
if len(warning) > 0 {
r.logger.Warn(warning)
}
if err != nil {
return err
return fmt.Errorf("environment variable SHADOWSOCKS_PORT: %w", err)
}
return nil
}
func (settings *ShadowSocks) getAddress(env params.Env) (
warning string, err error) {
address, err := env.Get("SHADOWSOCKS_LISTENING_ADDRESS")
if err != nil {
return "", fmt.Errorf("environment variable SHADOWSOCKS_LISTENING_ADDRESS: %w", err)
}
if address != "" {
address, warning, err := env.ListeningAddress("SHADOWSOCKS_LISTENING_ADDRESS")
if err != nil {
return "", fmt.Errorf("environment variable SHADOWSOCKS_LISTENING_ADDRESS: %w", err)
}
settings.Address = address
return warning, nil
}
// Retro-compatibility
const retroWarning = "You are using the old environment variable " +
"SHADOWSOCKS_PORT, please consider using " +
"SHADOWSOCKS_LISTENING_ADDRESS instead"
portStr, err := env.Get("SHADOWSOCKS_PORT")
if err != nil {
return retroWarning, fmt.Errorf("environment variable SHADOWSOCKS_PORT: %w", err)
} else if portStr != "" {
port, _, err := env.ListeningPort("SHADOWSOCKS_PORT")
if err != nil {
return retroWarning, fmt.Errorf("environment variable SHADOWSOCKS_PORT: %w", err)
}
settings.Address = ":" + fmt.Sprint(port)
return retroWarning, nil
}
// Default value
settings.Address = ":8388"
return "", nil
}

View File

@@ -1,8 +1,6 @@
package constants
import (
"sort"
"github.com/qdm12/gluetun/internal/models"
)
@@ -22,20 +20,11 @@ func CyberghostRegionChoices() (choices []string) {
func CyberghostGroupChoices() (choices []string) {
servers := CyberghostServers()
uniqueChoices := map[string]struct{}{}
for _, server := range servers {
uniqueChoices[server.Group] = struct{}{}
choices = make([]string, len(servers))
for i := range servers {
choices[i] = servers[i].Group
}
choices = make([]string, 0, len(uniqueChoices))
for choice := range uniqueChoices {
choices = append(choices, choice)
}
sortable := sort.StringSlice(choices)
sortable.Sort()
return sortable
return makeUnique(choices)
}
func CyberghostHostnameChoices() (choices []string) {

View File

@@ -1,18 +0,0 @@
package constants
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_CyberghostGroupChoices(t *testing.T) {
t.Parallel()
expected := []string{"Premium TCP Asia", "Premium TCP Europe",
"Premium TCP USA", "Premium UDP Asia", "Premium UDP Europe",
"Premium UDP USA"}
choices := CyberghostGroupChoices()
assert.Equal(t, expected, choices)
}

View File

@@ -48,12 +48,3 @@ func PIAServers() (servers []models.PIAServer) {
copy(servers, allServers.Pia.Servers)
return servers
}
func PIAServerWhereName(serverName string) (server models.PIAServer) {
for _, server := range PIAServers() {
if server.ServerName == serverName {
return server
}
}
return server
}

File diff suppressed because it is too large Load Diff

View File

@@ -51,7 +51,6 @@ func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) (
case <-ctx.Done():
case newStatus = <-s.running:
}
s.SetStatus(newStatus)
return newStatus.String(), nil

View File

@@ -42,7 +42,6 @@ func (l *Loop) collectLines(stdout, stderr <-chan string, done chan<- struct{})
}
if strings.Contains(line, "Initialization Sequence Completed") {
l.tunnelReady <- struct{}{}
l.startPFCh <- struct{}{}
}
}
}

View File

@@ -1,6 +1,7 @@
package openvpn
import (
"net"
"net/http"
"time"
@@ -10,8 +11,6 @@ import (
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn/state"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/logging"
)
@@ -23,6 +22,8 @@ type Looper interface {
loopstate.Applier
SettingsGetSetter
ServersGetterSetter
PortForwadedGetter
PortForwader
}
type Loop struct {
@@ -34,21 +35,19 @@ type Loop struct {
pgid int
targetConfPath string
// Configurators
conf StarterAuthWriter
fw firewallConfigurer
routing routing.VPNLocalGatewayIPGetter
portForward portforward.StartStopper
conf StarterAuthWriter
fw firewallConfigurer
// Other objects
logger logging.Logger
client *http.Client
tunnelReady chan<- struct{}
logger, pfLogger logging.Logger
client *http.Client
tunnelReady chan<- struct{}
// Internal channels and values
stop <-chan struct{}
stopped chan<- struct{}
start <-chan struct{}
running chan<- models.LoopStatus
userTrigger bool
startPFCh chan struct{}
stop <-chan struct{}
stopped chan<- struct{}
start <-chan struct{}
running chan<- models.LoopStatus
portForwardSignals chan net.IP
userTrigger bool
// Internal constant values
backoffTime time.Duration
}
@@ -64,8 +63,7 @@ const (
func NewLoop(settings configuration.OpenVPN, username string,
puid, pgid int, allServers models.AllServers, conf Configurator,
fw firewallConfigurer, routing routing.VPNLocalGatewayIPGetter,
portForward portforward.StartStopper, logger logging.Logger,
fw firewallConfigurer, logger logging.ParentLogger,
client *http.Client, tunnelReady chan<- struct{}) *Loop {
start := make(chan struct{})
running := make(chan models.LoopStatus)
@@ -76,25 +74,24 @@ func NewLoop(settings configuration.OpenVPN, username string,
state := state.New(statusManager, settings, allServers)
return &Loop{
statusManager: statusManager,
state: state,
username: username,
puid: puid,
pgid: pgid,
targetConfPath: constants.OpenVPNConf,
conf: conf,
fw: fw,
routing: routing,
portForward: portForward,
logger: logger,
client: client,
tunnelReady: tunnelReady,
start: start,
running: running,
stop: stop,
stopped: stopped,
userTrigger: true,
startPFCh: make(chan struct{}),
backoffTime: defaultBackoffTime,
statusManager: statusManager,
state: state,
username: username,
puid: puid,
pgid: pgid,
targetConfPath: constants.OpenVPNConf,
conf: conf,
fw: fw,
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}),
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}),
client: client,
tunnelReady: tunnelReady,
start: start,
running: running,
stop: stop,
stopped: stopped,
portForwardSignals: make(chan net.IP),
userTrigger: true,
backoffTime: defaultBackoffTime,
}
}

View File

@@ -1,54 +0,0 @@
package openvpn
import (
"context"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider"
)
func (l *Loop) startPortForwarding(ctx context.Context,
enabled bool, portForwarder provider.PortForwarder,
serverName string) {
if !enabled {
return
}
// only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP()
if err != nil {
l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error())
return
}
l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{
PortForwarder: portForwarder,
Gateway: gateway,
ServerName: serverName,
Interface: constants.TUN,
}
_, err = l.portForward.Start(ctx, pfData)
if err != nil {
l.logger.Error("cannot start port forwarding: " + err.Error())
}
}
func (l *Loop) stopPortForwarding(ctx context.Context, enabled bool,
timeout time.Duration) {
if !enabled {
return // nothing to stop
}
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
_, err := l.portForward.Stop(ctx)
if err != nil {
l.logger.Error("cannot stop port forwarding: " + err.Error())
}
}

View File

@@ -0,0 +1,39 @@
package openvpn
import (
"context"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/openvpn/state"
"github.com/qdm12/gluetun/internal/provider"
)
type PortForwadedGetter = state.PortForwardedGetter
func (l *Loop) GetPortForwarded() (port uint16) {
return l.state.GetPortForwarded()
}
type PortForwader interface {
PortForward(vpnGatewayIP net.IP)
}
func (l *Loop) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
// portForward is a blocking operation which may or may not be infinite.
// You should therefore always call it in a goroutine.
func (l *Loop) portForward(ctx context.Context,
providerConf provider.Provider, client *http.Client, gateway net.IP) {
settings := l.state.GetSettings()
if !settings.Provider.PortForwarding.Enabled {
return
}
syncState := func(port uint16) (pfFilepath string) {
l.state.SetPortForwarded(port)
settings := l.state.GetSettings()
return settings.Provider.PortForwarding.Filepath
}
providerConf.PortForward(ctx, client, l.pfLogger,
gateway, l.fw, syncState)
}

View File

@@ -88,33 +88,41 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
<-lineCollectionDone
}
// Needs the stream line from main.go to know when the tunnel is up
portForwardDone := make(chan struct{})
go func(ctx context.Context) {
defer close(portForwardDone)
select {
// TODO have a way to disable pf with a context
case <-ctx.Done():
return
case gateway := <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client, gateway)
}
}(openvpnCtx)
l.backoffTime = defaultBackoffTime
l.signalOrSetStatus(constants.Running)
stayHere := true
for stayHere {
select {
case <-l.startPFCh:
l.startPortForwarding(ctx, settings.Provider.PortForwarding.Enabled,
providerConf, connection.Hostname)
case <-ctx.Done():
const pfTimeout = 100 * time.Millisecond
l.stopPortForwarding(context.Background(),
settings.Provider.PortForwarding.Enabled, pfTimeout)
openvpnCancel()
<-waitError
close(waitError)
closeStreams()
<-portForwardDone
return
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
l.stopPortForwarding(ctx, settings.Provider.PortForwarding.Enabled, 0)
openvpnCancel()
<-waitError
// do not close waitError or the waitError
// select case will trigger
closeStreams()
<-portForwardDone
l.stopped <- struct{}{}
case <-l.start:
l.userTrigger = true
@@ -126,9 +134,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.statusManager.Lock() // prevent SetStatus from running in parallel
l.stopPortForwarding(ctx, settings.Provider.PortForwarding.Enabled, 0)
openvpnCancel()
l.statusManager.SetStatus(constants.Crashed)
<-portForwardDone
l.logAndWait(ctx, err)
stayHere = false

View File

@@ -13,6 +13,7 @@ var _ Manager = (*State)(nil)
type Manager interface {
SettingsGetSetter
ServersGetterSetter
PortForwardedGetterSetter
GetSettingsAndServers() (settings configuration.OpenVPN,
allServers models.AllServers)
}
@@ -35,6 +36,9 @@ type State struct {
allServers models.AllServers
allServersMu sync.RWMutex
portForwarded uint16
portForwardedMu sync.RWMutex
}
func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN,

View File

@@ -1,32 +0,0 @@
package portforward
import "context"
// firewallBlockPort obtains the state port thread safely and blocks
// it in the firewall if it is not the zero value (0).
func (l *Loop) firewallBlockPort(ctx context.Context) {
port := l.state.GetPortForwarded()
if port == 0 {
return
}
err := l.portAllower.RemoveAllowedPort(ctx, port)
if err != nil {
l.logger.Error("cannot block previous port in firewall: " + err.Error())
}
}
// firewallAllowPort obtains the state port thread safely and allows
// it in the firewall if it is not the zero value (0).
func (l *Loop) firewallAllowPort(ctx context.Context) {
port := l.state.GetPortForwarded()
if port == 0 {
return
}
startData := l.state.GetStartData()
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
if err != nil {
l.logger.Error("cannot allow port through firewall: " + err.Error())
}
}

View File

@@ -1,37 +0,0 @@
package portforward
import (
"fmt"
"os"
)
func (l *Loop) removePortForwardedFile() {
filepath := l.state.GetSettings().Filepath
l.logger.Info("removing port file " + filepath)
if err := os.Remove(filepath); err != nil {
l.logger.Error(err.Error())
}
}
func (l *Loop) writePortForwardedFile(port uint16) {
filepath := l.state.GetSettings().Filepath
l.logger.Info("writing port file " + filepath)
if err := writePortForwardedToFile(filepath, port); err != nil {
l.logger.Error(err.Error())
}
}
func writePortForwardedToFile(filepath string, port uint16) (err error) {
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprint(port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -1,9 +0,0 @@
package portforward
import "github.com/qdm12/gluetun/internal/portforward/state"
type Getter = state.PortForwardedGetter
func (l *Loop) GetPortForwarded() (port uint16) {
return l.state.GetPortForwarded()
}

View File

@@ -1,22 +0,0 @@
package portforward
import (
"context"
"time"
)
func (l *Loop) logAndWait(ctx context.Context, err error) {
if err != nil {
l.logger.Error(err.Error())
}
l.logger.Info("retrying in " + l.backoffTime.String())
timer := time.NewTimer(l.backoffTime)
l.backoffTime *= 2
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
}
}

View File

@@ -1,71 +0,0 @@
package portforward
import (
"net/http"
"sync"
"time"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/portforward/state"
"github.com/qdm12/golibs/logging"
)
var _ Looper = (*Loop)(nil)
type Looper interface {
Runner
loopstate.Getter
StartStopper
SettingsGetSetter
Getter
}
type Loop struct {
statusManager loopstate.Manager
state state.Manager
// Objects
client *http.Client
portAllower firewall.PortAllower
logger logging.Logger
// Internal channels and locks
start chan struct{}
running chan models.LoopStatus
stop chan struct{}
stopped chan struct{}
startMu sync.Mutex
backoffTime time.Duration
userTrigger bool
}
const defaultBackoffTime = 5 * time.Second
func NewLoop(settings configuration.PortForwarding,
client *http.Client, portAllower firewall.PortAllower,
logger logging.Logger) *Loop {
start := make(chan struct{})
running := make(chan models.LoopStatus)
stop := make(chan struct{})
stopped := make(chan struct{})
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, settings)
return &Loop{
statusManager: statusManager,
state: state,
// Objects
client: client,
portAllower: portAllower,
logger: logger,
start: start,
running: running,
stop: stop,
stopped: stopped,
userTrigger: true,
backoffTime: defaultBackoffTime,
}
}

View File

@@ -1,97 +0,0 @@
package portforward
import (
"context"
"strconv"
"github.com/qdm12/gluetun/internal/constants"
)
type Runner interface {
Run(ctx context.Context, done chan<- struct{})
}
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
defer close(done)
select {
case <-l.start: // l.state.SetStartData called beforehand
case <-ctx.Done():
return
}
for ctx.Err() == nil {
pfCtx, pfCancel := context.WithCancel(ctx)
portCh := make(chan uint16)
errorCh := make(chan error)
startData := l.state.GetStartData()
go func(ctx context.Context, startData StartData) {
port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger,
startData.Gateway, startData.ServerName)
if err != nil {
errorCh <- err
return
}
portCh <- port
// Infinite loop
err = startData.PortForwarder.KeepPortForward(ctx, l.client, l.logger,
port, startData.Gateway, startData.ServerName)
errorCh <- err
}(pfCtx, startData)
if l.userTrigger {
l.userTrigger = false
l.running <- constants.Running
} else { // crash
l.backoffTime = defaultBackoffTime
l.statusManager.SetStatus(constants.Running)
}
stayHere := true
for stayHere {
select {
case <-ctx.Done():
pfCancel()
<-errorCh
close(errorCh)
close(portCh)
l.removePortForwardedFile()
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(0)
return
case <-l.start:
l.userTrigger = true
l.logger.Info("starting")
pfCancel()
stayHere = false
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
pfCancel()
<-errorCh
l.removePortForwardedFile()
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(0)
l.stopped <- struct{}{}
case port := <-portCh:
l.logger.Info("port forwarded is " + strconv.Itoa(int(port)))
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(port)
l.firewallAllowPort(ctx)
l.writePortForwardedFile(port)
case err := <-errorCh:
pfCancel()
close(errorCh)
close(portCh)
l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err)
stayHere = false
}
}
pfCancel() // for linting
}
}

View File

@@ -1,19 +0,0 @@
package portforward
import (
"context"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/portforward/state"
)
type SettingsGetSetter = state.SettingsGetSetter
func (l *Loop) GetSettings() (settings configuration.PortForwarding) {
return l.state.GetSettings()
}
func (l *Loop) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
outcome string) {
return l.state.SetSettings(ctx, settings)
}

View File

@@ -1,55 +0,0 @@
package state
import (
"context"
"os"
"reflect"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
)
type SettingsGetSetter interface {
GetSettings() (settings configuration.PortForwarding)
SetSettings(ctx context.Context,
settings configuration.PortForwarding) (outcome string)
}
func (s *State) GetSettings() (settings configuration.PortForwarding) {
s.settingsMu.RLock()
defer s.settingsMu.RUnlock()
return s.settings
}
func (s *State) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
outcome string) {
s.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(s.settings, settings)
if settingsUnchanged {
s.settingsMu.Unlock()
return "settings left unchanged"
}
if s.settings.Filepath != settings.Filepath {
_ = os.Rename(s.settings.Filepath, settings.Filepath)
}
newEnabled := settings.Enabled
previousEnabled := s.settings.Enabled
s.settings = settings
s.settingsMu.Unlock()
switch {
case !newEnabled && !previousEnabled:
case newEnabled && previousEnabled:
// no need to restart for now since we os.Rename the file here.
case newEnabled && !previousEnabled:
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
case !newEnabled && previousEnabled:
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
}
return "settings updated"
}

View File

@@ -1,39 +0,0 @@
package state
import (
"net"
"github.com/qdm12/gluetun/internal/provider"
)
type StartData struct {
PortForwarder provider.PortForwarder
Gateway net.IP // needed for PIA
ServerName string // needed for PIA
Interface string // tun0 or wg0 for example
}
type StartDataGetterSetter interface {
StartDataGetter
StartDataSetter
}
type StartDataGetter interface {
GetStartData() (startData StartData)
}
func (s *State) GetStartData() (startData StartData) {
s.startDataMu.RLock()
defer s.startDataMu.RUnlock()
return s.startData
}
type StartDataSetter interface {
SetStartData(startData StartData)
}
func (s *State) SetStartData(startData StartData) {
s.startDataMu.Lock()
defer s.startDataMu.Unlock()
s.startData = startData
}

View File

@@ -1,37 +0,0 @@
package state
import (
"sync"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/loopstate"
)
var _ Manager = (*State)(nil)
type Manager interface {
SettingsGetSetter
PortForwardedGetterSetter
StartDataGetterSetter
}
func New(statusApplier loopstate.Applier,
settings configuration.PortForwarding) *State {
return &State{
statusApplier: statusApplier,
settings: settings,
}
}
type State struct {
statusApplier loopstate.Applier
settings configuration.PortForwarding
settingsMu sync.RWMutex
portForwarded uint16
portForwardedMu sync.RWMutex
startData StartData
startDataMu sync.RWMutex
}

View File

@@ -1,33 +0,0 @@
package portforward
import (
"context"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/portforward/state"
)
func (l *Loop) GetStatus() (status models.LoopStatus) {
return l.statusManager.GetStatus()
}
type StartData = state.StartData
type StartStopper interface {
Start(ctx context.Context, data StartData) (
outcome string, err error)
Stop(ctx context.Context) (outcome string, err error)
}
func (l *Loop) Start(ctx context.Context, data StartData) (
outcome string, err error) {
l.startMu.Lock()
defer l.startMu.Unlock()
l.state.SetStartData(data)
return l.statusManager.ApplyStatus(ctx, constants.Running)
}
func (l *Loop) Stop(ctx context.Context) (outcome string, err error) {
return l.statusManager.ApplyStatus(ctx, constants.Stopped)
}

View File

@@ -1,41 +1,18 @@
package cyberghost
import (
"errors"
"fmt"
"strings"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/utils"
)
var ErrGroupMismatchesProtocol = errors.New("server group does not match protocol")
func (c *Cyberghost) filterServers(selection configuration.ServerSelection) (
servers []models.CyberghostServer, err error) {
if len(selection.Groups) == 0 {
if selection.TCP {
selection.Groups = tcpGroupChoices()
} else {
selection.Groups = udpGroupChoices()
}
}
// Check each group match the protocol
groupsCheckFn := groupsAreAllUDP
if selection.TCP {
groupsCheckFn = groupsAreAllTCP
}
if err := groupsCheckFn(selection.Groups); err != nil {
return nil, err
}
for _, server := range c.servers {
switch {
case
utils.FilterByPossibilities(server.Group, selection.Groups),
case selection.Group != "" && !strings.EqualFold(selection.Group, server.Group), // TODO make CSV
utils.FilterByPossibilities(server.Region, selection.Regions),
utils.FilterByPossibilities(server.Hostname, selection.Hostnames):
default:
@@ -49,51 +26,3 @@ func (c *Cyberghost) filterServers(selection configuration.ServerSelection) (
return servers, nil
}
func tcpGroupChoices() (choices []string) {
const tcp = true
return groupsForTCP(tcp)
}
func udpGroupChoices() (choices []string) {
const tcp = false
return groupsForTCP(tcp)
}
func groupsForTCP(tcp bool) (choices []string) {
allGroups := constants.CyberghostGroupChoices()
choices = make([]string, 0, len(allGroups))
for _, group := range allGroups {
switch {
case tcp && groupIsTCP(group):
choices = append(choices, group)
case !tcp && !groupIsTCP(group):
choices = append(choices, group)
}
}
return choices
}
func groupIsTCP(group string) bool {
return strings.Contains(strings.ToLower(group), "tcp")
}
func groupsAreAllTCP(groups []string) error {
for _, group := range groups {
if !groupIsTCP(group) {
return fmt.Errorf("%w: group %s for protocol TCP",
ErrGroupMismatchesProtocol, group)
}
}
return nil
}
func groupsAreAllUDP(groups []string) error {
for _, group := range groups {
if groupIsTCP(group) {
return fmt.Errorf("%w: group %s for protocol UDP",
ErrGroupMismatchesProtocol, group)
}
}
return nil
}

View File

@@ -21,102 +21,77 @@ func Test_Cyberghost_filterServers(t *testing.T) {
"no servers": {
err: errors.New("no server found: for protocol udp"),
},
"servers without filter defaults to UDP": {
"servers without filter": {
servers: []models.CyberghostServer{
{Region: "a", Group: "Premium TCP Asia"},
{Region: "b", Group: "Premium TCP Europe"},
{Region: "c", Group: "Premium UDP Asia"},
{Region: "d", Group: "Premium UDP Europe"},
{Region: "a", Group: "1"},
{Region: "b", Group: "1"},
{Region: "c", Group: "2"},
{Region: "d", Group: "2"},
},
filteredServers: []models.CyberghostServer{
{Region: "c", Group: "Premium UDP Asia"},
{Region: "d", Group: "Premium UDP Europe"},
},
},
"servers with TCP selection": {
servers: []models.CyberghostServer{
{Region: "a", Group: "Premium TCP Asia"},
{Region: "b", Group: "Premium TCP Europe"},
{Region: "c", Group: "Premium UDP Asia"},
{Region: "d", Group: "Premium UDP Europe"},
},
selection: configuration.ServerSelection{
TCP: true,
},
filteredServers: []models.CyberghostServer{
{Region: "a", Group: "Premium TCP Asia"},
{Region: "b", Group: "Premium TCP Europe"},
{Region: "a", Group: "1"},
{Region: "b", Group: "1"},
{Region: "c", Group: "2"},
{Region: "d", Group: "2"},
},
},
"servers with regions filter": {
servers: []models.CyberghostServer{
{Region: "a", Group: "Premium UDP Asia"},
{Region: "b", Group: "Premium UDP Asia"},
{Region: "c", Group: "Premium UDP Asia"},
{Region: "d", Group: "Premium UDP Asia"},
{Region: "a", Group: "1"},
{Region: "b", Group: "1"},
{Region: "c", Group: "2"},
{Region: "d", Group: "2"},
},
selection: configuration.ServerSelection{
Regions: []string{"a", "c"},
},
filteredServers: []models.CyberghostServer{
{Region: "a", Group: "Premium UDP Asia"},
{Region: "c", Group: "Premium UDP Asia"},
{Region: "a", Group: "1"},
{Region: "c", Group: "2"},
},
},
"servers with group filter": {
servers: []models.CyberghostServer{
{Region: "a", Group: "Premium UDP Europe"},
{Region: "b", Group: "Premium UDP Europe"},
{Region: "c", Group: "Premium TCP Europe"},
{Region: "d", Group: "Premium TCP Europe"},
{Region: "a", Group: "1"},
{Region: "b", Group: "1"},
{Region: "c", Group: "2"},
{Region: "d", Group: "2"},
},
selection: configuration.ServerSelection{
Groups: []string{"Premium UDP Europe"},
Group: "1",
},
filteredServers: []models.CyberghostServer{
{Region: "a", Group: "Premium UDP Europe"},
{Region: "b", Group: "Premium UDP Europe"},
{Region: "a", Group: "1"},
{Region: "b", Group: "1"},
},
},
"servers with bad group filter": {
servers: []models.CyberghostServer{
{Region: "a", Group: "Premium TCP Europe"},
{Region: "b", Group: "Premium TCP Europe"},
{Region: "c", Group: "Premium UDP Europe"},
{Region: "d", Group: "Premium UDP Europe"},
},
selection: configuration.ServerSelection{
Groups: []string{"Premium TCP Europe"},
},
err: errors.New("server group does not match protocol: group Premium TCP Europe for protocol UDP"),
},
"servers with regions and group filter": {
servers: []models.CyberghostServer{
{Region: "a", Group: "Premium UDP Europe"},
{Region: "b", Group: "Premium TCP Europe"},
{Region: "c", Group: "Premium UDP Asia"},
{Region: "d", Group: "Premium TCP Asia"},
{Region: "a", Group: "1"},
{Region: "b", Group: "1"},
{Region: "c", Group: "2"},
{Region: "d", Group: "2"},
},
selection: configuration.ServerSelection{
Regions: []string{"a", "c"},
Groups: []string{"Premium UDP Europe"},
Group: "1",
},
filteredServers: []models.CyberghostServer{
{Region: "a", Group: "Premium UDP Europe"},
{Region: "a", Group: "1"},
},
},
"servers with hostnames filter": {
servers: []models.CyberghostServer{
{Hostname: "a", Group: "Premium UDP Asia"},
{Hostname: "b", Group: "Premium UDP Asia"},
{Hostname: "c", Group: "Premium UDP Asia"},
{Hostname: "a"},
{Hostname: "b"},
{Hostname: "c"},
},
selection: configuration.ServerSelection{
Hostnames: []string{"a", "c"},
},
filteredServers: []models.CyberghostServer{
{Hostname: "a", Group: "Premium UDP Asia"},
{Hostname: "c", Group: "Premium UDP Asia"},
{Hostname: "a"},
{Hostname: "c"},
},
},
}
@@ -138,25 +113,3 @@ func Test_Cyberghost_filterServers(t *testing.T) {
})
}
}
func Test_tcpGroupChoices(t *testing.T) {
t.Parallel()
expected := []string{
"Premium TCP Asia", "Premium TCP Europe", "Premium TCP USA",
}
choices := tcpGroupChoices()
assert.Equal(t, expected, choices)
}
func Test_udpGroupChoices(t *testing.T) {
t.Parallel()
expected := []string{
"Premium UDP Asia", "Premium UDP Europe", "Premium UDP USA",
}
choices := udpGroupChoices()
assert.Equal(t, expected, choices)
}

View File

@@ -35,6 +35,7 @@ func (c *Cyberghost) BuildConf(connection models.OpenVPNConnection,
// Cyberghost specific
// "redirect-gateway def1",
"ncp-disable",
"explicit-exit-notify 2",
"script-security 2",
"route-delay 5",
@@ -55,10 +56,6 @@ func (c *Cyberghost) BuildConf(connection models.OpenVPNConnection,
lines = append(lines, utils.CipherLines(settings.Cipher, settings.Version)...)
if connection.Protocol == constants.UDP {
lines = append(lines, "explicit-exit-notify")
}
if strings.HasSuffix(settings.Cipher, "-gcm") {
lines = append(lines, "ncp-ciphers AES-256-GCM:AES-256-CBC:AES-128-GCM")
}

View File

@@ -31,7 +31,6 @@ func (p *PIA) GetOpenVPNConnection(selection configuration.ServerSelection) (
IP: IP,
Port: port,
Protocol: protocol,
Hostname: server.ServerName, // used for port forwarding TLS
}
connections = append(connections, connection)
}
@@ -47,5 +46,20 @@ func (p *PIA) GetOpenVPNConnection(selection configuration.ServerSelection) (
return connection, err
}
p.activeServer = findActiveServer(servers, connection)
return connection, nil
}
func findActiveServer(servers []models.PIAServer,
connection models.OpenVPNConnection) (activeServer models.PIAServer) {
// Reverse lookup server using the randomly picked connection
for _, server := range servers {
for _, ip := range server.IPs {
if connection.IP.Equal(ip) {
return server
}
}
}
return activeServer
}

View File

@@ -16,51 +16,47 @@ import (
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/format"
"github.com/qdm12/golibs/logging"
)
var (
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
ErrServerNameEmpty = errors.New("server name is empty")
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
ErrBindPort = errors.New("cannot bind port")
ErrBindPort = errors.New("cannot bind port")
)
// PortForward obtains a VPN server side port forwarded from PIA.
//nolint:gocognit
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
logger logging.Logger, gateway net.IP, serverName string) (
port uint16, err error) {
server := constants.PIAServerWhereName(serverName)
if !server.PortForward {
logger.Error("The server " + serverName +
" (region " + server.Region + ") does not support port forwarding")
logger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
syncState func(port uint16) (pfFilepath string)) {
commonName := p.activeServer.ServerName
if !p.activeServer.PortForward {
logger.Error("The server " + commonName +
" (region " + p.activeServer.Region + ") does not support port forwarding")
return
}
if gateway == nil {
return 0, ErrGatewayIPIsNil
} else if serverName == "" {
return 0, ErrServerNameEmpty
logger.Error("aborting because: VPN gateway IP address was not found")
return
}
privateIPClient, err := newHTTPClient(serverName)
privateIPClient, err := newHTTPClient(commonName)
if err != nil {
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
logger.Error("aborting because: " + err.Error())
return
}
data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil {
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
logger.Error(err.Error())
}
dataFound := data.Port > 0
durationToExpiration := data.Expiration.Sub(p.timeNow())
expired := durationToExpiration <= 0
if dataFound {
logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port)))
logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port)))
if expired {
logger.Warn("Forwarded port data expired on " +
data.Expiration.Format(time.RFC1123) + ", getting another one")
@@ -70,65 +66,99 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
}
if !dataFound || expired {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
tryUntilSuccessful(ctx, logger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
return err
})
if ctx.Err() != nil {
return
}
durationToExpiration = data.Expiration.Sub(p.timeNow())
}
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
" expiring in " + format.FriendlyDuration(durationToExpiration))
// First time binding
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
tryUntilSuccessful(ctx, logger, func() error {
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
return fmt.Errorf("%w: %s", ErrBindPort, err)
}
return nil
})
if ctx.Err() != nil {
return
}
return data.Port, nil
}
var (
ErrPortForwardedExpired = errors.New("port forwarded data expired")
)
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error) {
privateIPClient, err := newHTTPClient(serverName)
if err != nil {
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error(err.Error())
}
data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil {
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
logger.Error(err.Error())
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
expiryTimer := time.NewTimer(durationToExpiration)
const keepAlivePeriod = 15 * time.Minute
// Timer behaving as a ticker
keepAliveTimer := time.NewTimer(keepAlivePeriod)
for {
select {
case <-ctx.Done():
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := portAllower.RemoveAllowedPort(removeCtx, data.Port); err != nil {
logger.Error(err.Error())
}
if !keepAliveTimer.Stop() {
<-keepAliveTimer.C
}
if !expiryTimer.Stop() {
<-expiryTimer.C
}
return ctx.Err()
return
case <-keepAliveTimer.C:
err := bindPort(ctx, privateIPClient, gateway, data)
if err != nil {
return fmt.Errorf("%w: %s", ErrBindPort, err)
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
logger.Error("cannot bind port: " + err.Error())
}
keepAliveTimer.Reset(keepAlivePeriod)
case <-expiryTimer.C:
return fmt.Errorf("%w: on %s", ErrPortForwardedExpired,
data.Expiration.Format(time.RFC1123))
logger.Warn("Forward port has expired on " +
data.Expiration.Format(time.RFC1123) + ", getting another one")
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
logger.Error(err.Error())
continue
}
break
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
" expiring in " + format.FriendlyDuration(durationToExpiration))
if err := portAllower.RemoveAllowedPort(ctx, oldPort); err != nil {
logger.Error(err.Error())
}
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
logger.Error(err.Error())
}
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error("Cannot write port forward data to file: " + err.Error())
}
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
logger.Error("Cannot bind port: " + err.Error())
}
if !keepAliveTimer.Stop() {
<-keepAliveTimer.C
}
keepAliveTimer.Reset(keepAlivePeriod)
expiryTimer.Reset(durationToExpiration)
}
}
}
@@ -433,6 +463,21 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
return nil
}
func writePortForwardedToFile(filepath string, port uint16) (err error) {
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}
// replaceInErr is used to remove sensitive information from errors.
func replaceInErr(err error, substitutions map[string]string) error {
s := replaceInString(err.Error(), substitutions)

View File

@@ -9,9 +9,10 @@ import (
)
type PIA struct {
servers []models.PIAServer
randSource rand.Source
timeNow func() time.Time
servers []models.PIAServer
randSource rand.Source
timeNow func() time.Time
activeServer models.PIAServer
// Port forwarding
portForwardPath string
authFilePath string

View File

@@ -0,0 +1,31 @@
package privateinternetaccess
import (
"context"
"time"
"github.com/qdm12/golibs/logging"
)
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
const initialRetryPeriod = 5 * time.Second
retryPeriod := initialRetryPeriod
for {
err := fn()
if err == nil {
break
}
logger.Error(err.Error())
logger.Info("Trying again in " + retryPeriod.String())
timer := time.NewTimer(retryPeriod)
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return
}
retryPeriod *= 2
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/cyberghost"
"github.com/qdm12/gluetun/internal/provider/fastestvpn"
@@ -35,16 +36,9 @@ import (
type Provider interface {
GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error)
BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string)
PortForwarder
}
type PortForwarder interface {
PortForward(ctx context.Context, client *http.Client,
logger logging.Logger, gateway net.IP, serverName string) (
port uint16, err error)
KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error)
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
syncState func(port uint16) (pfFilepath string))
}
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {

View File

@@ -25,13 +25,8 @@ func NoServerFoundError(selection configuration.ServerSelection) (err error) {
}
messageParts = append(messageParts, "protocol "+protocol)
switch len(selection.Countries) {
case 0:
case 1:
part := "group " + selection.Groups[0]
messageParts = append(messageParts, part)
default:
part := "groups " + commaJoin(selection.Groups)
if selection.Group != "" {
part := "group " + selection.Group
messageParts = append(messageParts, part)
}

View File

@@ -2,21 +2,17 @@ package utils
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging"
)
type NoPortForwarder interface {
PortForward(ctx context.Context, client *http.Client,
logger logging.Logger, gateway net.IP, serverName string) (
port uint16, err error)
KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error)
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
syncState func(port uint16) (pfFilepath string))
}
type NoPortForwarding struct {
@@ -29,16 +25,8 @@ func NewNoPortForwarding(providerName string) *NoPortForwarding {
}
}
var ErrPortForwardingNotSupported = errors.New("custom port forwarding obtention is not supported")
func (n *NoPortForwarding) PortForward(ctx context.Context, client *http.Client,
logger logging.Logger, gateway net.IP, serverName string) (
port uint16, err error) {
return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
}
func (n *NoPortForwarding) KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error) {
return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
syncState func(port uint16) (pfFilepath string)) {
panic("custom port forwarding obtention is not supported for " + n.providerName)
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging"
@@ -17,14 +16,13 @@ import (
func newHandler(ctx context.Context, logger logging.Logger, logging bool,
buildInfo models.BuildInformation,
openvpnLooper openvpn.Looper,
pfGetter portforward.Getter,
unboundLooper dns.Looper,
updaterLooper updater.Looper,
publicIPLooper publicip.Looper,
) http.Handler {
handler := &handler{}
openvpn := newOpenvpnHandler(ctx, openvpnLooper, pfGetter, logger)
openvpn := newOpenvpnHandler(ctx, openvpnLooper, logger)
dns := newDNSHandler(ctx, unboundLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger)

View File

@@ -7,16 +7,14 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/golibs/logging"
)
func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper,
pfGetter portforward.Getter, logger logging.Logger) http.Handler {
logger logging.Logger) http.Handler {
return &openvpnHandler{
ctx: ctx,
looper: looper,
pf: pfGetter,
logger: logger,
}
}
@@ -24,7 +22,6 @@ func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper,
type openvpnHandler struct {
ctx context.Context
looper openvpn.Looper
pf portforward.Getter
logger logging.Logger
}
@@ -108,7 +105,7 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
}
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
port := h.pf.GetPortForwarded()
port := h.looper.GetPortForwarded()
encoder := json.NewEncoder(w)
data := portWrapper{Port: port}
if err := encoder.Encode(data); err != nil {

View File

@@ -10,7 +10,6 @@ import (
"github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging"
@@ -27,11 +26,11 @@ type server struct {
}
func New(ctx context.Context, address string, logEnabled bool, logger logging.Logger,
buildInfo models.BuildInformation, openvpnLooper openvpn.Looper,
pfGetter portforward.Getter, unboundLooper dns.Looper,
buildInfo models.BuildInformation,
openvpnLooper openvpn.Looper, unboundLooper dns.Looper,
updaterLooper updater.Looper, publicIPLooper publicip.Looper) Server {
handler := newHandler(ctx, logger, logEnabled, buildInfo,
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper)
openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
return &server{
address: address,
logger: logger,

View File

@@ -0,0 +1,32 @@
package shadowsocks
import "github.com/qdm12/golibs/logging"
type logAdapter struct {
logger logging.Logger
enabled bool
}
func (l *logAdapter) Info(s string) {
if l.enabled {
l.logger.Info(s)
}
}
func (l *logAdapter) Debug(s string) {
if l.enabled {
l.logger.Debug(s)
}
}
func (l *logAdapter) Error(s string) {
if l.enabled {
l.logger.Error(s)
}
}
func adaptLogger(logger logging.Logger, enabled bool) *logAdapter {
return &logAdapter{
logger: logger,
enabled: enabled,
}
}

View File

@@ -3,6 +3,7 @@ package shadowsocks
import (
"context"
"strconv"
"sync"
"time"
@@ -87,7 +88,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
for ctx.Err() == nil {
settings := l.GetSettings()
server, err := shadowsockslib.NewServer(settings.Settings, l.logger)
server, err := shadowsockslib.NewServer(settings.Method, settings.Password, adaptLogger(l.logger, settings.Log))
if err != nil {
crashed = true
l.logAndWait(ctx, err)
@@ -98,7 +99,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
waitError := make(chan error)
go func() {
waitError <- server.Listen(shadowsocksCtx)
waitError <- server.Listen(shadowsocksCtx, ":"+strconv.Itoa(int(settings.Port)))
}()
if err != nil {
crashed = true

View File

@@ -35,9 +35,9 @@ func GetServers(ctx context.Context, client *http.Client, minServers int) (
if node.IP2 != nil {
ips = append(ips, node.IP2)
}
// if node.IP3 != nil { // Wireguard + Stealth
// ips = append(ips, node.IP3)
// }
if node.IP3 != nil {
ips = append(ips, node.IP3)
}
server := models.WindscribeServer{
Region: region,
City: city,