Compare commits

..

14 Commits

Author SHA1 Message Date
Quentin McGaw
db886163c2 Public IP getter loop refactored 2020-12-28 01:51:55 +00:00
Quentin McGaw
91f5338db0 Fix updater loop bug 2020-12-28 01:50:13 +00:00
Quentin McGaw
82a02287ac Public IP endpoint with GET /ip fixing #319 2020-12-27 21:06:00 +00:00
Quentin McGaw
2dc674559e Re-use username for UID if it exists 2020-12-27 00:36:39 +00:00
Quentin McGaw
38e713fea2 Fix Block-outside-dns #316 2020-12-23 06:46:54 +00:00
Quentin McGaw
2cbb14c36c Fix Purevpn settings display, refers to #317 2020-12-22 14:08:12 +00:00
Quentin McGaw
610e88958e Upgrade golangci-lint to v1.33.0 2020-12-22 13:52:37 +00:00
Quentin McGaw
bb76477467 Fix #316 2020-12-22 13:49:49 +00:00
Quentin McGaw
433a799759 Fix environment variables table for Purevpn 2020-12-22 13:46:52 +00:00
Quentin McGaw
22965ccce3 Fix #315 2020-12-22 06:21:25 +00:00
Quentin McGaw
4257581f55 Loops and HTTP control server rework (#308)
- CRUD REST HTTP server
- `/v1` HTTP server prefix
- Retrocompatible with older routes (redirects to v1 or handles the requests directly)
- DNS, Updater and Openvpn refactored to have a REST-like state with new methods to change their states synchronously
- Openvpn, Unbound and Updater status, see #287
2020-12-19 20:10:34 -05:00
Quentin McGaw
d60d629105 Dev container documentation and cleanup 2020-12-08 06:24:46 +00:00
Quentin McGaw
3f721b1717 Simplify Github workflows triggers 2020-12-07 02:15:50 +00:00
Quentin McGaw
97049bfab4 Add 256x256 png logo for Unraid 2020-12-07 02:11:23 +00:00
69 changed files with 3472 additions and 733 deletions

View File

@@ -0,0 +1,5 @@
.dockerignore
devcontainer.json
docker-compose.yml
Dockerfile
README.md

1
.devcontainer/Dockerfile Normal file
View File

@@ -0,0 +1 @@
FROM qmcgaw/godevcontainer

68
.devcontainer/README.md Normal file
View File

@@ -0,0 +1,68 @@
# Development container
Development container that can be used with VSCode.
It works on Linux, Windows and OSX.
## Requirements
- [VS code](https://code.visualstudio.com/download) installed
- [VS code remote containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) installed
- [Docker](https://www.docker.com/products/docker-desktop) installed and running
- If you don't use Linux or WSL 2, share your home directory `~/` and the directory of your project with Docker Desktop
- [Docker Compose](https://docs.docker.com/compose/install/) installed
- Ensure your host has the following and that they are accessible by Docker:
- `~/.ssh` directory
- `~/.gitconfig` file (can be empty)
## Setup
1. Open the command palette in Visual Studio Code (CTRL+SHIFT+P).
1. Select `Remote-Containers: Open Folder in Container...` and choose the project directory.
## Customization
### Customize the image
You can make changes to the [Dockerfile](Dockerfile) and then rebuild the image. For example, your Dockerfile could be:
```Dockerfile
FROM qmcgaw/godevcontainer
USER root
RUN apk add curl
USER vscode
```
Note that you may need to use `USER root` to build as root, and then change back to `USER vscode`.
To rebuild the image, either:
- With VSCode through the command palette, select `Remote-Containers: Rebuild and reopen in container`
- With a terminal, go to this directory and `docker-compose build`
### Customize VS code settings
You can customize **settings** and **extensions** in the [devcontainer.json](devcontainer.json) definition file.
### Entrypoint script
You can bind mount a shell script to `/home/vscode/.welcome.sh` to replace the [current welcome script](shell/.welcome.sh).
### Publish a port
To access a port from your host to your development container, publish a port in [docker-compose.yml](docker-compose.yml).
### Run other services
1. Modify [docker-compose.yml](docker-compose.yml) to launch other services at the same time as this development container, such as a test database:
```yml
database:
image: postgres
restart: always
environment:
POSTGRES_PASSWORD: password
```
1. In [devcontainer.json](devcontainer.json), change the line `"runServices": ["vscode"],` to `"runServices": ["vscode", "database"],`.
1. In the VS code command palette, rebuild the container.

View File

@@ -1,5 +1,5 @@
{ {
"name": "pia-dev", "name": "gluetun-dev",
"dockerComposeFile": [ "dockerComposeFile": [
"docker-compose.yml" "docker-compose.yml"
], ],
@@ -12,27 +12,25 @@
"workspaceFolder": "/workspace", "workspaceFolder": "/workspace",
"extensions": [ "extensions": [
"golang.go", "golang.go",
"IBM.output-colorizer", "eamodio.gitlens", // IDE Git information
"eamodio.gitlens",
"mhutchie.git-graph",
"davidanson.vscode-markdownlint", "davidanson.vscode-markdownlint",
"shardulm94.trailing-spaces", "ms-azuretools.vscode-docker", // Docker integration and linting
"alefragnani.Bookmarks", "shardulm94.trailing-spaces", // Show trailing spaces
"Gruntfuggly.todo-tree", "Gruntfuggly.todo-tree", // Highlights TODO comments
"mohsen1.prettify-json", "bierner.emojisense", // Emoji sense for markdown
"quicktype.quicktype", "stkb.rewrap", // rewrap comments after n characters on one line
"spikespaz.vscode-smoothtype", "vscode-icons-team.vscode-icons", // Better file extension icons
"stkb.rewrap", "github.vscode-pull-request-github", // Github interaction
"vscode-icons-team.vscode-icons" "redhat.vscode-yaml", // Kubernetes, Drone syntax highlighting
"bajdzis.vscode-database", // Supports connections to mysql or postgres, over SSL, socked
"IBM.output-colorizer", // Colorize your output/test logs
"mohsen1.prettify-json", // Prettify JSON data
], ],
"settings": { "settings": {
// General settings
"files.eol": "\n", "files.eol": "\n",
// Docker
"remote.extensionKind": { "remote.extensionKind": {
"ms-azuretools.vscode-docker": "workspace" "ms-azuretools.vscode-docker": "workspace"
}, },
// Golang general settings
"go.useLanguageServer": true, "go.useLanguageServer": true,
"go.autocompleteUnimportedPackages": true, "go.autocompleteUnimportedPackages": true,
"go.gotoSymbol.includeImports": true, "go.gotoSymbol.includeImports": true,
@@ -43,7 +41,6 @@
"usePlaceholders": false "usePlaceholders": false
}, },
"go.lintTool": "golangci-lint", "go.lintTool": "golangci-lint",
// Golang on save
"go.buildOnSave": "workspace", "go.buildOnSave": "workspace",
"go.lintOnSave": "workspace", "go.lintOnSave": "workspace",
"go.vetOnSave": "workspace", "go.vetOnSave": "workspace",
@@ -53,20 +50,21 @@
"source.organizeImports": true "source.organizeImports": true
} }
}, },
// Golang testing
"go.toolsEnvVars": { "go.toolsEnvVars": {
"GOFLAGS": "-tags=integration" "GOFLAGS": "-tags=",
// "CGO_ENABLED": 1 // for the race detector
}, },
"gopls.env": { "gopls.env": {
"GOFLAGS": "-tags=integration" "GOFLAGS": "-tags="
}, },
"go.testEnvVars": {}, "go.testEnvVars": {},
"go.testFlags": [ "go.testFlags": [
"-v", "-v",
// "-race" // "-race"
], ],
"go.testTimeout": "600s", "go.testTimeout": "10s",
"go.coverOnSingleTest": true,
"go.coverOnSingleTestFile": true, "go.coverOnSingleTestFile": true,
"go.coverOnSingleTest": true "go.coverOnTestPackage": true
} }
} }

View File

@@ -2,14 +2,24 @@ version: "3.7"
services: services:
vscode: vscode:
image: qmcgaw/godevcontainer build: .
image: godevcontainer
volumes: volumes:
- ../:/workspace - ../:/workspace
# Docker socket to access Docker server
- /var/run/docker.sock:/var/run/docker.sock
# SSH directory
- ~/.ssh:/home/vscode/.ssh - ~/.ssh:/home/vscode/.ssh
- ~/.ssh:/root/.ssh - ~/.ssh:/root/.ssh
- /var/run/docker.sock:/var/run/docker.sock # Git config
- ~/.gitconfig:/home/districter/.gitconfig
- ~/.gitconfig:/root/.gitconfig
environment:
- TZ=
cap_add: cap_add:
# For debugging with dlv
- SYS_PTRACE - SYS_PTRACE
security_opt: security_opt:
# For debugging with dlv
- seccomp:unconfined - seccomp:unconfined
entrypoint: zsh -c "while sleep 1000; do :; done" entrypoint: zsh -c "while sleep 1000; do :; done"

View File

@@ -2,28 +2,15 @@ name: Docker build
on: on:
pull_request: pull_request:
branches: [master] branches: [master]
paths-ignore: paths:
- .devcontainer - .github/workflows/build.yml
- .github/ISSUE_TEMPLATE - cmd/**
- .github/workflows/buildx-release.yml - internal/**
- .github/workflows/buildx-branch.yml - .dockerignore
- .github/workflows/buildx-latest.yml - .golangci.yml
- .github/workflows/dockerhub-description.yml - Dockerfile
- .github/workflows/labels.yml - go.mod
- .github/workflows/misspell.yml - go.sum
- .github/CODEOWNERS
- .github/CONTRIBUTING.md
- .github/FUNDING.yml
- .github/labels.yml
- .vscode
- cmd/ovpnparser
- cmd/resolver
- doc
- .gitignore
- docker-compose.yml
- LICENSE
- README.md
- title.svg
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -5,28 +5,15 @@ on:
- '*' - '*'
- '*/*' - '*/*'
- '!master' - '!master'
paths-ignore: paths:
- .devcontainer - .github/workflows/buildx-branch.yml
- .github/ISSUE_TEMPLATE - cmd/**
- .github/workflows/build.yml - internal/**
- .github/workflows/buildx-release.yml - .dockerignore
- .github/workflows/buildx-latest.yml - .golangci.yml
- .github/workflows/dockerhub-description.yml - Dockerfile
- .github/workflows/labels.yml - go.mod
- .github/workflows/misspell.yml - go.sum
- .github/CODEOWNERS
- .github/CONTRIBUTING.md
- .github/FUNDING.yml
- .github/labels.yml
- .vscode
- cmd/ovpnparser
- cmd/resolver
- doc
- .gitignore
- docker-compose.yml
- LICENSE
- README.md
- title.svg
jobs: jobs:
buildx: buildx:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -2,28 +2,15 @@ name: Buildx latest
on: on:
push: push:
branches: [master] branches: [master]
paths-ignore: paths:
- .devcontainer - .github/workflows/buildx-latest.yml
- .github/ISSUE_TEMPLATE - cmd/**
- .github/workflows/build.yml - internal/**
- .github/workflows/buildx-branch.yml - .dockerignore
- .github/workflows/buildx-release.yml - .golangci.yml
- .github/workflows/dockerhub-description.yml - Dockerfile
- .github/workflows/labels.yml - go.mod
- .github/workflows/misspell.yml - go.sum
- .github/CODEOWNERS
- .github/CONTRIBUTING.md
- .github/FUNDING.yml
- .github/labels.yml
- .vscode
- cmd/ovpnparser
- cmd/resolver
- doc
- .gitignore
- docker-compose.yml
- LICENSE
- README.md
- title.svg
jobs: jobs:
buildx: buildx:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -2,28 +2,15 @@ name: Buildx release
on: on:
release: release:
types: [published] types: [published]
paths-ignore: paths:
- .devcontainer - .github/workflows/buildx-release.yml
- .github/ISSUE_TEMPLATE - cmd/**
- .github/workflows/build.yml - internal/**
- .github/workflows/buildx-branch.yml - .dockerignore
- .github/workflows/buildx-latest.yml - .golangci.yml
- .github/workflows/dockerhub-description.yml - Dockerfile
- .github/workflows/labels.yml - go.mod
- .github/workflows/misspell.yml - go.sum
- .github/CODEOWNERS
- .github/CONTRIBUTING.md
- .github/FUNDING.yml
- .github/labels.yml
- .vscode
- cmd/ovpnparser
- cmd/resolver
- doc
- .gitignore
- docker-compose.yml
- LICENSE
- README.md
- title.svg
jobs: jobs:
buildx: buildx:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -1,10 +1,10 @@
name: labels name: labels
on: on:
push: push:
branches: ["master"] branches: [master]
paths: paths:
- '.github/labels.yml' - .github/labels.yml
- '.github/workflows/labels.yml' - .github/workflows/labels.yml
jobs: jobs:
labeler: labeler:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -4,7 +4,7 @@ ARG GO_VERSION=1.15
FROM golang:${GO_VERSION}-alpine${ALPINE_VERSION} AS builder FROM golang:${GO_VERSION}-alpine${ALPINE_VERSION} AS builder
RUN apk --update add git RUN apk --update add git
ENV CGO_ENABLED=0 ENV CGO_ENABLED=0
ARG GOLANGCI_LINT_VERSION=v1.31.0 ARG GOLANGCI_LINT_VERSION=v1.33.0
RUN wget -O- -nv https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s ${GOLANGCI_LINT_VERSION} RUN wget -O- -nv https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s ${GOLANGCI_LINT_VERSION}
WORKDIR /tmp/gobuild WORKDIR /tmp/gobuild
COPY .golangci.yml . COPY .golangci.yml .
@@ -18,10 +18,10 @@ COPY internal/ ./internal/
RUN go test ./... RUN go test ./...
RUN golangci-lint run --timeout=10m RUN golangci-lint run --timeout=10m
RUN go build -trimpath -ldflags="-s -w \ RUN go build -trimpath -ldflags="-s -w \
-X 'main.version=$VERSION' \ -X 'main.version=$VERSION' \
-X 'main.buildDate=$BUILD_DATE' \ -X 'main.buildDate=$BUILD_DATE' \
-X 'main.commit=$COMMIT' \ -X 'main.commit=$COMMIT' \
" -o entrypoint main.go " -o entrypoint main.go
FROM alpine:${ALPINE_VERSION} FROM alpine:${ALPINE_VERSION}
ARG VERSION=unknown ARG VERSION=unknown
@@ -47,7 +47,7 @@ ENV VPNSP=pia \
TZ= \ TZ= \
UID=1000 \ UID=1000 \
GID=1000 \ GID=1000 \
IP_STATUS_FILE="/tmp/gluetun/ip" \ PUBLICIP_FILE="/tmp/gluetun/ip" \
# PIA, Windscribe, Surfshark, Cyberghost, Vyprvpn, NordVPN, PureVPN only # PIA, Windscribe, Surfshark, Cyberghost, Vyprvpn, NordVPN, PureVPN only
USER= \ USER= \
PASSWORD= \ PASSWORD= \

View File

@@ -97,7 +97,7 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo
| Variable | Default | Choices | Description | | Variable | Default | Choices | Description |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| 🏁 `VPNSP` | `private internet access` | `private internet access`, `mullvad`, `windscribe`, `surfshark`, `vyprvpn`, `nordvpn`, `purevpn`, `privado` | VPN Service Provider | | 🏁 `VPNSP` | `private internet access` | `private internet access`, `mullvad`, `windscribe`, `surfshark`, `vyprvpn`, `nordvpn`, `purevpn`, `privado` | VPN Service Provider |
| `IP_STATUS_FILE` | `/tmp/gluetun/ip` | Any filepath | Filepath to store the public IP address assigned | | `PUBLICIP_FILE` | `/tmp/gluetun/ip` | Any filepath | Filepath to store the public IP address assigned |
| `PROTOCOL` | `udp` | `udp` or `tcp` | Network protocol to use | | `PROTOCOL` | `udp` | `udp` or `tcp` | Network protocol to use |
| `OPENVPN_VERBOSITY` | `1` | `0` to `6` | Openvpn verbosity level | | `OPENVPN_VERBOSITY` | `1` | `0` to `6` | Openvpn verbosity level |
| `OPENVPN_ROOT` | `no` | `yes` or `no` | Run OpenVPN as root | | `OPENVPN_ROOT` | `no` | `yes` or `no` | Run OpenVPN as root |
@@ -192,8 +192,9 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo
| Variable | Default | Choices | Description | | Variable | Default | Choices | Description |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| 🏁 `USER` | | | Your user ID | | 🏁 `USER` | | | Your username |
| 🏁 `REGION` | | One of the [PureVPN regions](https://support.purevpn.com/vpn-servers) | VPN server region | | 🏁 `PASSWORD` | | | Your password |
| `REGION` | | One of the [PureVPN regions](https://support.purevpn.com/vpn-servers) | VPN server region |
| `COUNTRY` | | One of the [PureVPN countries](https://support.purevpn.com/vpn-servers) | VPN server country | | `COUNTRY` | | One of the [PureVPN countries](https://support.purevpn.com/vpn-servers) | VPN server country |
| `CITY` | | One of the [PureVPN cities](https://support.purevpn.com/vpn-servers) | VPN server city | | `CITY` | | One of the [PureVPN cities](https://support.purevpn.com/vpn-servers) | VPN server city |

View File

@@ -118,7 +118,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
// Should never change // Should never change
uid, gid := allSettings.System.UID, allSettings.System.GID uid, gid := allSettings.System.UID, allSettings.System.GID
err = alpineConf.CreateUser("nonrootuser", uid) const defaultUsername = "nonrootuser"
nonRootUsername, err := alpineConf.CreateUser(defaultUsername, uid)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return 1 return 1
@@ -217,31 +218,29 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, uid, gid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel) ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
wg.Add(1) wg.Add(1)
// wait for restartOpenvpn // wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg) go openvpnLooper.Run(ctx, wg)
updaterOptions := updater.NewOptions("127.0.0.1") updaterLooper := updater.NewLooper(allSettings.Updater,
updaterLooper := updater.NewLooper(updaterOptions, allSettings.UpdaterPeriod, allServers, storage, openvpnLooper.SetServers, httpClient, logger)
allServers, storage, openvpnLooper.SetAllServers, httpClient, logger)
wg.Add(1) wg.Add(1)
// wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker // wait for updaterLooper.Restart() or its ticket launched with RunRestartTicker
go updaterLooper.Run(ctx, wg) go updaterLooper.Run(ctx, wg)
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, nonRootUsername, uid, gid)
wg.Add(1) wg.Add(1)
// wait for unboundLooper.Restart or its ticker launched with RunRestartTicker // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg, signalDNSReady) go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, publicIPLooper := publicip.NewLooper(
allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) client, logger, fileManager, allSettings.PublicIP, uid, gid)
wg.Add(1) wg.Add(1)
go publicIPLooper.Run(ctx, wg) go publicIPLooper.Run(ctx, wg)
wg.Add(1) wg.Add(1)
go publicIPLooper.RunRestartTicker(ctx, wg) go publicIPLooper.RunRestartTicker(ctx, wg)
publicIPLooper.SetPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy) httpProxyLooper := httpproxy.NewLooper(logger, allSettings.HTTPProxy)
wg.Add(1) wg.Add(1)
@@ -267,7 +266,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
controlServerAddress := fmt.Sprintf("0.0.0.0:%d", allSettings.ControlServer.Port) controlServerAddress := fmt.Sprintf("0.0.0.0:%d", allSettings.ControlServer.Port)
controlServerLogging := allSettings.ControlServer.Log controlServerLogging := allSettings.ControlServer.Log
httpServer := server.New(controlServerAddress, controlServerLogging, httpServer := server.New(controlServerAddress, controlServerLogging,
logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper) logger, buildInfo, openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
wg.Add(1) wg.Add(1)
go httpServer.Run(ctx, wg) go httpServer.Run(ctx, wg)
@@ -276,8 +275,9 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
wg.Add(1) wg.Add(1)
go healthcheckServer.Run(ctx, wg) go healthcheckServer.Run(ctx, wg)
// Start openvpn for the first time // Start openvpn for the first time in a blocking call
openvpnLooper.Restart() // until openvpn is launched
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
signalsCh := make(chan os.Signal, 1) signalsCh := make(chan os.Signal, 1)
signal.Notify(signalsCh, signal.Notify(signalsCh,
@@ -293,11 +293,6 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
case <-ctx.Done(): case <-ctx.Done():
logger.Warn("context canceled, shutting down") logger.Warn("context canceled, shutting down")
} }
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
logger.Error(err)
shutdownErrorsCount++
}
if allSettings.OpenVPN.Provider.PortForwarding.Enabled { if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
@@ -401,7 +396,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
tickerWg.Wait() tickerWg.Wait()
return return
case <-tunnelReadyCh: // blocks until openvpn is connected case <-tunnelReadyCh: // blocks until openvpn is connected
unboundLooper.Restart() _, _ = unboundLooper.SetStatus(constants.Running)
restartTickerCancel() // stop previous restart tickers restartTickerCancel() // stop previous restart tickers
tickerWg.Wait() tickerWg.Wait()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx) restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
@@ -424,7 +419,8 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
startPortForward(vpnGateway) startPortForward(vpnGateway)
} }
case <-dnsReadyCh: case <-dnsReadyCh:
publicIPLooper.Restart() // TODO do not restart if disabled // Runs the Public IP getter job once
_, _ = publicIPLooper.SetStatus(constants.Running)
if !versionInformation { if !versionInformation {
break break
} }

1720
doc/logo.svg Normal file

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 62 KiB

BIN
doc/logo_256.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@@ -7,7 +7,7 @@ import (
) )
type Configurator interface { type Configurator interface {
CreateUser(username string, uid int) error CreateUser(username string, uid int) (createdUsername string, err error)
} }
type configurator struct { type configurator struct {

View File

@@ -6,34 +6,34 @@ import (
) )
// CreateUser creates a user in Alpine with the given UID. // CreateUser creates a user in Alpine with the given UID.
func (c *configurator) CreateUser(username string, uid int) error { func (c *configurator) CreateUser(username string, uid int) (createdUsername string, err error) {
UIDStr := fmt.Sprintf("%d", uid) UIDStr := fmt.Sprintf("%d", uid)
u, err := c.lookupUID(UIDStr) u, err := c.lookupUID(UIDStr)
_, unknownUID := err.(user.UnknownUserIdError) _, unknownUID := err.(user.UnknownUserIdError)
if err != nil && !unknownUID { if err != nil && !unknownUID {
return fmt.Errorf("cannot create user: %w", err) return "", fmt.Errorf("cannot create user: %w", err)
} else if u != nil { } else if u != nil {
if u.Username == username { if u.Username == username {
return nil return "", nil
} }
return fmt.Errorf("user with ID %d exists with username %q instead of %q", uid, u.Username, username) return u.Username, nil
} }
u, err = c.lookupUser(username) u, err = c.lookupUser(username)
_, unknownUsername := err.(user.UnknownUserError) _, unknownUsername := err.(user.UnknownUserError)
if err != nil && !unknownUsername { if err != nil && !unknownUsername {
return fmt.Errorf("cannot create user: %w", err) return "", fmt.Errorf("cannot create user: %w", err)
} else if u != nil { } else if u != nil {
return fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d", return "", fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d",
username, u.Uid, uid) username, u.Uid, uid)
} }
passwd, err := c.fileManager.ReadFile("/etc/passwd") passwd, err := c.fileManager.ReadFile("/etc/passwd")
if err != nil { if err != nil {
return fmt.Errorf("cannot create user: %w", err) return "", fmt.Errorf("cannot create user: %w", err)
} }
passwd = append(passwd, []byte(fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid))...) passwd = append(passwd, []byte(fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid))...)
if err := c.fileManager.WriteToFile("/etc/passwd", passwd); err != nil { if err := c.fileManager.WriteToFile("/etc/passwd", passwd); err != nil {
return fmt.Errorf("cannot create user: %w", err) return "", fmt.Errorf("cannot create user: %w", err)
} }
return nil return username, nil
} }

View File

@@ -71,8 +71,7 @@ func OpenvpnConfig() error {
lines := providerConf.BuildConf( lines := providerConf.BuildConf(
connection, connection,
allSettings.OpenVPN.Verbosity, allSettings.OpenVPN.Verbosity,
allSettings.System.UID, "nonroortuser",
allSettings.System.GID,
allSettings.OpenVPN.Root, allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher, allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth, allSettings.OpenVPN.Auth,
@@ -83,7 +82,7 @@ func OpenvpnConfig() error {
} }
func Update(args []string) error { func Update(args []string) error {
options := updater.Options{CLI: true} options := settings.Updater{CLI: true}
var flushToFile bool var flushToFile bool
flagSet := flag.NewFlagSet("update", flag.ExitOnError) flagSet := flag.NewFlagSet("update", flag.ExitOnError)
flagSet.BoolVar(&flushToFile, "file", false, "Write results to /gluetun/servers.json (for end users)") flagSet.BoolVar(&flushToFile, "file", false, "Write results to /gluetun/servers.json (for end users)")

View File

@@ -97,7 +97,7 @@ func DNSProviderMapping() map[models.DNSProvider]models.DNSProviderData {
} }
} }
// Block lists URLs // Block lists URLs.
//nolint:lll //nolint:lll
const ( const (
AdsBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated" AdsBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated"

View File

@@ -0,0 +1,14 @@
package constants
import (
"github.com/qdm12/gluetun/internal/models"
)
const (
Starting models.LoopStatus = "starting"
Running models.LoopStatus = "running"
Stopping models.LoopStatus = "stopping"
Stopped models.LoopStatus = "stopped"
Crashed models.LoopStatus = "crashed"
Completed models.LoopStatus = "completed"
)

View File

@@ -14,9 +14,10 @@ import (
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
) )
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS, uid, gid int) (err error) { func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS,
username string, uid, gid int) (err error) {
c.logger.Info("generating Unbound configuration") c.logger.Info("generating Unbound configuration")
lines, warnings := generateUnboundConf(ctx, settings, c.client, c.logger) lines, warnings := generateUnboundConf(ctx, settings, username, c.client, c.logger)
for _, warning := range warnings { for _, warning := range warnings {
c.logger.Warn(warning) c.logger.Warn(warning)
} }
@@ -28,7 +29,7 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
} }
// MakeUnboundConf generates an Unbound configuration from the user provided settings. // MakeUnboundConf generates an Unbound configuration from the user provided settings.
func generateUnboundConf(ctx context.Context, settings settings.DNS, func generateUnboundConf(ctx context.Context, settings settings.DNS, username string,
client network.Client, logger logging.Logger) ( client network.Client, logger logging.Logger) (
lines []string, warnings []error) { lines []string, warnings []error) {
doIPv6 := "no" doIPv6 := "no"
@@ -69,7 +70,7 @@ func generateUnboundConf(ctx context.Context, settings settings.DNS,
"interface": "0.0.0.0", "interface": "0.0.0.0",
"port": "53", "port": "53",
// Other // Other
"username": "\"nonrootuser\"", "username": fmt.Sprintf("%q", username),
} }
// Block lists // Block lists

View File

@@ -41,7 +41,7 @@ func Test_generateUnboundConf(t *testing.T) {
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("%d hostnames blocked overall", 2).Times(1) logger.EXPECT().Info("%d hostnames blocked overall", 2).Times(1)
logger.EXPECT().Info("%d IP addresses blocked overall", 3).Times(1) logger.EXPECT().Info("%d IP addresses blocked overall", 3).Times(1)
lines, warnings := generateUnboundConf(ctx, settings, client, logger) lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
require.Len(t, warnings, 0) require.Len(t, warnings, 0)
expected := ` expected := `
server: server:

View File

@@ -15,7 +15,7 @@ import (
type Configurator interface { type Configurator interface {
DownloadRootHints(ctx context.Context, uid, gid int) error DownloadRootHints(ctx context.Context, uid, gid int) error
DownloadRootKey(ctx context.Context, uid, gid int) error DownloadRootKey(ctx context.Context, uid, gid int) error
MakeUnboundConf(ctx context.Context, settings settings.DNS, uid, gid int) (err error) MakeUnboundConf(ctx context.Context, settings settings.DNS, username string, uid, gid int) (err error)
UseDNSInternally(IP net.IP) UseDNSInternally(IP net.IP)
UseDNSSystemWide(ip net.IP, keepNameserver bool) error UseDNSSystemWide(ip net.IP, keepNameserver bool) error
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
@@ -15,80 +16,53 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func())
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() GetStatus() (status models.LoopStatus)
Start() SetStatus(status models.LoopStatus) (outcome string, err error)
Stop()
GetSettings() (settings settings.DNS) GetSettings() (settings settings.DNS)
SetSettings(settings settings.DNS) SetSettings(settings settings.DNS) (outcome string)
} }
type looper struct { type looper struct {
conf Configurator state state
settings settings.DNS conf Configurator
settingsMutex sync.RWMutex logger logging.Logger
logger logging.Logger streamMerger command.StreamMerger
streamMerger command.StreamMerger username string
uid int uid int
gid int gid int
restart chan struct{} loopLock sync.Mutex
start chan struct{} start chan struct{}
stop chan struct{} running chan models.LoopStatus
updateTicker chan struct{} stop chan struct{}
timeNow func() time.Time stopped chan struct{}
timeSince func(time.Time) time.Duration updateTicker chan struct{}
timeNow func() time.Time
timeSince func(time.Time) time.Duration
} }
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
streamMerger command.StreamMerger, uid, gid int) Looper { streamMerger command.StreamMerger, username string, uid, gid int) Looper {
return &looper{ return &looper{
state: state{
status: constants.Stopped,
settings: settings,
},
conf: conf, conf: conf,
settings: settings,
logger: logger.WithPrefix("dns over tls: "), logger: logger.WithPrefix("dns over tls: "),
username: username,
uid: uid, uid: uid,
gid: gid, gid: gid,
streamMerger: streamMerger, streamMerger: streamMerger,
restart: make(chan struct{}),
start: make(chan struct{}), start: make(chan struct{}),
running: make(chan models.LoopStatus),
stop: make(chan struct{}), stop: make(chan struct{}),
stopped: make(chan struct{}),
updateTicker: make(chan struct{}), updateTicker: make(chan struct{}),
timeNow: time.Now, timeNow: time.Now,
timeSince: time.Since, timeSince: time.Since,
} }
} }
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Start() { l.start <- struct{}{} }
func (l *looper) Stop() { l.stop <- struct{}{} }
func (l *looper) GetSettings() (settings settings.DNS) {
l.settingsMutex.RLock()
defer l.settingsMutex.RUnlock()
return l.settings
}
func (l *looper) SetSettings(settings settings.DNS) {
l.settingsMutex.Lock()
defer l.settingsMutex.Unlock()
updatePeriodDiffers := l.settings.UpdatePeriod != settings.UpdatePeriod
l.settings = settings
l.settingsMutex.Unlock()
if updatePeriodDiffers {
l.updateTicker <- struct{}{}
}
}
func (l *looper) isEnabled() bool {
l.settingsMutex.RLock()
defer l.settingsMutex.RUnlock()
return l.settings.Enabled
}
func (l *looper) setEnabled(enabled bool) {
l.settingsMutex.Lock()
defer l.settingsMutex.Unlock()
l.settings.Enabled = enabled
}
func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Warn(err) l.logger.Warn(err)
l.logger.Info("attempting restart in 10 seconds") l.logger.Info("attempting restart in 10 seconds")
@@ -103,96 +77,42 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) waitForFirstStart(ctx context.Context, signalDNSReady func()) {
for {
select {
case <-l.stop:
l.setEnabled(false)
l.logger.Info("not started yet")
case <-l.restart:
if l.isEnabled() {
return
}
signalDNSReady()
l.logger.Info("not restarting because disabled")
case <-l.start:
l.setEnabled(true)
return
case <-ctx.Done():
return
}
}
}
func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel context.CancelFunc) {
if l.isEnabled() {
return
}
for {
// wait for a signal to re-enable
select {
case <-l.stop:
l.logger.Info("already disabled")
case <-l.restart:
if !l.isEnabled() {
l.logger.Info("not restarting because disabled")
} else {
return
}
case <-l.start:
l.setEnabled(true)
return
case <-ctx.Done():
unboundCancel()
return
}
}
}
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady func()) {
defer wg.Done() defer wg.Done()
const fallback = false const fallback = false
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback) // TODO remove? Use default DNS by default for Docker resolution?
l.waitForFirstStart(ctx, signalDNSReady)
if ctx.Err() != nil { select {
case <-l.start:
case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited") defer l.logger.Warn("loop exited")
var unboundCtx context.Context
var unboundCancel context.CancelFunc = func() {}
var waitError chan error
triggeredRestart := false
l.setEnabled(true)
for ctx.Err() == nil { for ctx.Err() == nil {
l.waitForSubsequentStart(ctx, unboundCancel) err := l.updateFiles(ctx)
if err == nil {
break
}
l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err)
}
crashed := false
for ctx.Err() == nil {
settings := l.GetSettings() settings := l.GetSettings()
// Setup unboundCtx, unboundCancel := context.WithCancel(context.Background())
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
l.logAndWait(ctx, err)
continue
}
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
l.logAndWait(ctx, err)
continue
}
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
l.logAndWait(ctx, err)
continue
}
if triggeredRestart {
triggeredRestart = false
unboundCancel()
<-waitError
close(waitError)
}
unboundCtx, unboundCancel = context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel) stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel)
if err != nil { if err != nil {
unboundCancel() unboundCancel()
if !crashed {
l.running <- constants.Crashed
}
crashed = true
const fallback = true const fallback = true
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
@@ -201,23 +121,37 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
// Started successfully // Started successfully
go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound")) go l.streamMerger.Merge(unboundCtx, stream, command.MergeName("unbound"))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}, settings.KeepNameserver); err != nil { // use Unbound
l.logger.Error(err) l.logger.Error(err)
} }
if err := l.conf.WaitForUnbound(); err != nil { if err := l.conf.WaitForUnbound(); err != nil {
if !crashed {
l.running <- constants.Crashed
crashed = true
}
unboundCancel() unboundCancel()
const fallback = true const fallback = true
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
continue continue
} }
waitError = make(chan error)
waitError := make(chan error)
go func() { go func() {
err := waitFn() // blocking err := waitFn() // blocking
waitError <- err waitError <- err
}() }()
l.logger.Info("DNS over TLS is ready") l.logger.Info("DNS over TLS is ready")
if !crashed {
l.running <- constants.Running
crashed = false
} else {
l.state.setStatusWithLock(constants.Running)
}
signalDNSReady() signalDNSReady()
stayHere := true stayHere := true
@@ -229,31 +163,28 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
<-waitError <-waitError
close(waitError) close(waitError)
return return
case <-l.restart: // triggered restart
l.logger.Info("restarting")
// unboundCancel occurs next loop run when the setup is complete
triggeredRestart = true
stayHere = false
case <-l.start:
l.logger.Info("already started")
case <-l.stop: case <-l.stop:
l.logger.Info("stopping") l.logger.Info("stopping")
const fallback = false
l.useUnencryptedDNS(fallback)
unboundCancel() unboundCancel()
<-waitError <-waitError
close(waitError) l.stopped <- struct{}{}
l.setEnabled(false) case <-l.start:
l.logger.Info("starting")
stayHere = false stayHere = false
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
close(waitError)
unboundCancel() unboundCancel()
l.state.setStatusWithLock(constants.Crashed)
const fallback = true const fallback = true
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
stayHere = false stayHere = false
} }
} }
close(waitError)
unboundCancel()
} }
unboundCancel()
} }
func (l *looper) useUnencryptedDNS(fallback bool) { func (l *looper) useUnencryptedDNS(fallback bool) {
@@ -279,7 +210,11 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
data := constants.DNSProviderMapping()[provider] data := constants.DNSProviderMapping()[provider]
for _, targetIP = range data.IPs { for _, targetIP = range data.IPs {
if targetIP.To4() != nil { if targetIP.To4() != nil {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP) if fallback {
l.logger.Info("falling back on plaintext DNS at address %s", targetIP)
} else {
l.logger.Info("using plaintext DNS at address %s", targetIP)
}
l.conf.UseDNSInternally(targetIP) l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil { if err := l.conf.UseDNSSystemWide(targetIP, settings.KeepNameserver); err != nil {
l.logger.Error(err) l.logger.Error(err)
@@ -314,7 +249,20 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
return return
case <-timer.C: case <-timer.C:
lastTick = l.timeNow() lastTick = l.timeNow()
l.restart <- struct{}{}
status := l.GetStatus()
if status == constants.Running {
if err := l.updateFiles(ctx); err != nil {
l.state.setStatusWithLock(constants.Crashed)
l.logger.Error(err)
l.logger.Warn("skipping Unbound restart due to failed files update")
continue
}
}
_, _ = l.SetStatus(constants.Stopped)
_, _ = l.SetStatus(constants.Running)
settings := l.GetSettings() settings := l.GetSettings()
timer.Reset(settings.UpdatePeriod) timer.Reset(settings.UpdatePeriod)
case <-l.updateTicker: case <-l.updateTicker:
@@ -337,3 +285,17 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
} }
} }
} }
func (l *looper) updateFiles(ctx context.Context) (err error) {
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
return err
}
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
return err
}
settings := l.GetSettings()
if err := l.conf.MakeUnboundConf(ctx, settings, l.username, l.uid, l.gid); err != nil {
return err
}
return nil
}

96
internal/dns/state.go Normal file
View File

@@ -0,0 +1,96 @@
package dns
import (
"fmt"
"reflect"
"sync"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
)
type state struct {
status models.LoopStatus
settings settings.DNS
statusMu sync.RWMutex
settingsMu sync.RWMutex
}
func (s *state) setStatusWithLock(status models.LoopStatus) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.status = status
}
func (l *looper) GetStatus() (status models.LoopStatus) {
l.state.statusMu.RLock()
defer l.state.statusMu.RUnlock()
return l.state.status
}
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
l.state.statusMu.Lock()
defer l.state.statusMu.Unlock()
existingStatus := l.state.status
switch status {
case constants.Running:
switch existingStatus {
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Starting
l.state.statusMu.Unlock()
l.start <- struct{}{}
newStatus := <-l.running
l.state.statusMu.Lock()
l.state.status = newStatus
return newStatus.String(), nil
case constants.Stopped:
switch existingStatus {
case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Stopping
l.state.statusMu.Unlock()
l.stop <- struct{}{}
<-l.stopped
l.state.statusMu.Lock()
l.state.status = constants.Stopped
return status.String(), nil
default:
return "", fmt.Errorf("status %q can only be %q or %q",
status, constants.Running, constants.Stopped)
}
}
func (l *looper) GetSettings() (settings settings.DNS) {
l.state.settingsMu.RLock()
defer l.state.settingsMu.RUnlock()
return l.state.settings
}
func (l *looper) SetSettings(settings settings.DNS) (outcome string) {
l.state.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
if settingsUnchanged {
l.state.settingsMu.Unlock()
return "settings left unchanged"
}
tempSettings := l.state.settings
tempSettings.UpdatePeriod = settings.UpdatePeriod
onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings)
l.state.settings = settings
if onlyUpdatePeriodChanged {
l.updateTicker <- struct{}{}
return "update period changed"
}
_, _ = l.SetStatus(constants.Stopped)
outcome, _ = l.SetStatus(constants.Running)
return outcome
}

View File

@@ -20,8 +20,14 @@ type (
VPNProvider string VPNProvider string
// NetworkProtocol contains the network protocol to be used to communicate with the VPN servers. // NetworkProtocol contains the network protocol to be used to communicate with the VPN servers.
NetworkProtocol string NetworkProtocol string
// Loop status such as stopped or running.
LoopStatus string
) )
func (ls LoopStatus) String() string {
return string(ls)
}
func marshalJSONString(s string) (data []byte, err error) { func marshalJSONString(s string) (data []byte, err error) {
return []byte(fmt.Sprintf("%q", s)), nil return []byte(fmt.Sprintf("%q", s)), nil
} }

View File

@@ -3,5 +3,5 @@ package models
type BuildInformation struct { type BuildInformation struct {
Version string `json:"version"` Version string `json:"version"`
Commit string `json:"commit"` Commit string `json:"commit"`
BuildDate string `json:"buildDate"` BuildDate string `json:"build_date"`
} }

View File

@@ -1,6 +1,8 @@
package models package models
import "net" import (
"net"
)
type OpenVPNConnection struct { type OpenVPNConnection struct {
IP net.IP IP net.IP

View File

@@ -9,15 +9,15 @@ import (
// ProviderSettings contains settings specific to a VPN provider. // ProviderSettings contains settings specific to a VPN provider.
type ProviderSettings struct { type ProviderSettings struct {
Name VPNProvider `json:"name"` Name VPNProvider `json:"name"`
ServerSelection ServerSelection `json:"serverSelection"` ServerSelection ServerSelection `json:"server_selection"`
ExtraConfigOptions ExtraConfigOptions `json:"extraConfig"` ExtraConfigOptions ExtraConfigOptions `json:"extra_config"`
PortForwarding PortForwarding `json:"portForwarding"` PortForwarding PortForwarding `json:"port_forwarding"`
} }
type ServerSelection struct { type ServerSelection struct {
// Common // Common
Protocol NetworkProtocol `json:"networkProtocol"` Protocol NetworkProtocol `json:"network_protocol"`
TargetIP net.IP `json:"targetIP,omitempty"` TargetIP net.IP `json:"target_ip,omitempty"`
// Cyberghost, PIA, Surfshark, Windscribe, Vyprvpn, NordVPN // Cyberghost, PIA, Surfshark, Windscribe, Vyprvpn, NordVPN
Regions []string `json:"regions"` Regions []string `json:"regions"`
@@ -34,20 +34,20 @@ type ServerSelection struct {
Owned bool `json:"owned"` Owned bool `json:"owned"`
// Mullvad, Windscribe // Mullvad, Windscribe
CustomPort uint16 `json:"customPort"` CustomPort uint16 `json:"custom_port"`
// NordVPN // NordVPN
Numbers []uint16 `json:"numbers"` Numbers []uint16 `json:"numbers"`
// PIA // PIA
EncryptionPreset string `json:"encryptionPreset"` EncryptionPreset string `json:"encryption_preset"`
} }
type ExtraConfigOptions struct { type ExtraConfigOptions struct {
ClientCertificate string `json:"-"` // Cyberghost ClientCertificate string `json:"-"` // Cyberghost
ClientKey string `json:"-"` // Cyberghost ClientKey string `json:"-"` // Cyberghost
EncryptionPreset string `json:"encryptionPreset"` // PIA EncryptionPreset string `json:"encryption_preset"` // PIA
OpenVPNIPv6 bool `json:"openvpnIPv6"` // Mullvad OpenVPNIPv6 bool `json:"openvpn_ipv6"` // Mullvad
} }
// PortForwarding contains settings for port forwarding. // PortForwarding contains settings for port forwarding.

View File

@@ -20,26 +20,22 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
Restart() GetStatus() (status models.LoopStatus)
PortForward(vpnGatewayIP net.IP) SetStatus(status models.LoopStatus) (outcome string, err error)
GetSettings() (settings settings.OpenVPN) GetSettings() (settings settings.OpenVPN)
SetSettings(settings settings.OpenVPN) SetSettings(settings settings.OpenVPN) (outcome string)
GetPortForwarded() (portForwarded uint16) GetServers() (servers models.AllServers)
SetAllServers(allServers models.AllServers) SetServers(servers models.AllServers)
GetPortForwarded() (port uint16)
PortForward(vpnGatewayIP net.IP)
} }
type looper struct { type looper struct {
// Variable parameters state state
provider models.VPNProvider
settings settings.OpenVPN
settingsMutex sync.RWMutex
portForwarded uint16
portForwardedMutex sync.RWMutex
allServers models.AllServers
allServersMutex sync.RWMutex
// Fixed parameters // Fixed parameters
uid int username string
gid int uid int
gid int
// Configurators // Configurators
conf Configurator conf Configurator
fw firewall.Configurator fw firewall.Configurator
@@ -50,22 +46,28 @@ type looper struct {
fileManager files.FileManager fileManager files.FileManager
streamMerger command.StreamMerger streamMerger command.StreamMerger
cancel context.CancelFunc cancel context.CancelFunc
// Internal channels // Internal channels and locks
restart chan struct{} loopLock sync.Mutex
running chan models.LoopStatus
stop, stopped chan struct{}
start chan struct{}
portForwardSignals chan net.IP portForwardSignals chan net.IP
} }
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, func NewLooper(settings settings.OpenVPN,
uid, gid int, allServers models.AllServers, username string, uid, gid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, routing routing.Routing, conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.Logger, client *http.Client, fileManager files.FileManager, logger logging.Logger, client *http.Client, fileManager files.FileManager,
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
return &looper{ return &looper{
provider: provider, state: state{
settings: settings, status: constants.Stopped,
settings: settings,
allServers: allServers,
},
username: username,
uid: uid, uid: uid,
gid: gid, gid: gid,
allServers: allServers,
conf: conf, conf: conf,
fw: fw, fw: fw,
routing: routing, routing: routing,
@@ -75,46 +77,29 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
fileManager: fileManager, fileManager: fileManager,
streamMerger: streamMerger, streamMerger: streamMerger,
cancel: cancel, cancel: cancel,
restart: make(chan struct{}), start: make(chan struct{}),
running: make(chan models.LoopStatus),
stop: make(chan struct{}),
stopped: make(chan struct{}),
portForwardSignals: make(chan net.IP), portForwardSignals: make(chan net.IP),
} }
} }
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
func (l *looper) GetSettings() (settings settings.OpenVPN) {
l.settingsMutex.RLock()
defer l.settingsMutex.RUnlock()
return l.settings
}
func (l *looper) SetSettings(settings settings.OpenVPN) {
l.settingsMutex.Lock()
defer l.settingsMutex.Unlock()
l.settings = settings
}
func (l *looper) SetAllServers(allServers models.AllServers) {
l.allServersMutex.Lock()
defer l.allServersMutex.Unlock()
l.allServers = allServers
}
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
crashed := false
select { select {
case <-l.restart: case <-l.start:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited") defer l.logger.Warn("loop exited")
for ctx.Err() == nil { for ctx.Err() == nil {
settings := l.GetSettings() settings, allServers := l.state.getSettingsAndServers()
l.allServersMutex.RLock() providerConf := provider.New(settings.Provider.Name, allServers, time.Now)
providerConf := provider.New(l.provider, l.allServers, time.Now)
l.allServersMutex.RUnlock()
connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection) connection, err := providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
@@ -124,8 +109,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
lines := providerConf.BuildConf( lines := providerConf.BuildConf(
connection, connection,
settings.Verbosity, settings.Verbosity,
l.uid, l.username,
l.gid,
settings.Root, settings.Root,
settings.Cipher, settings.Cipher,
settings.Auth, settings.Auth,
@@ -155,6 +139,10 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
stream, waitFn, err := l.conf.Start(openvpnCtx) stream, waitFn, err := l.conf.Start(openvpnCtx)
if err != nil { if err != nil {
openvpnCancel() openvpnCancel()
if !crashed {
l.running <- constants.Crashed
crashed = true
}
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
continue continue
} }
@@ -179,23 +167,41 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
err := waitFn() // blocking err := waitFn() // blocking
waitError <- err waitError <- err
}() }()
select {
case <-ctx.Done(): if !crashed {
l.logger.Warn("context canceled: exiting loop") l.running <- constants.Running
openvpnCancel() crashed = false
<-waitError } else {
close(waitError) l.state.setStatusWithLock(constants.Running)
return
case <-l.restart: // triggered restart
l.logger.Info("restarting")
openvpnCancel()
<-waitError
close(waitError)
case err := <-waitError: // unexpected error
openvpnCancel()
close(waitError)
l.logAndWait(ctx, err)
} }
stayHere := true
for stayHere {
select {
case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
openvpnCancel()
<-waitError
close(waitError)
return
case <-l.stop:
l.logger.Info("stopping")
openvpnCancel()
<-waitError
l.stopped <- struct{}{}
case <-l.start:
l.logger.Info("starting")
stayHere = false
case err := <-waitError: // unexpected error
openvpnCancel()
l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err)
crashed = true
stayHere = false
}
}
close(waitError)
openvpnCancel() // just for the linter
} }
} }
@@ -218,24 +224,21 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup, func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
providerConf provider.Provider, client *http.Client, gateway net.IP) { providerConf provider.Provider, client *http.Client, gateway net.IP) {
defer wg.Done() defer wg.Done()
settings := l.GetSettings() l.state.portForwardedMu.RLock()
settings := l.state.settings
l.state.portForwardedMu.RUnlock()
if !settings.Provider.PortForwarding.Enabled { if !settings.Provider.PortForwarding.Enabled {
return return
} }
syncState := func(port uint16) (pfFilepath models.Filepath) { syncState := func(port uint16) (pfFilepath models.Filepath) {
l.portForwardedMutex.Lock() l.state.portForwardedMu.Lock()
l.portForwarded = port defer l.state.portForwardedMu.Unlock()
l.portForwardedMutex.Unlock() l.state.portForwarded = port
settings := l.GetSettings() l.state.settingsMu.RLock()
defer l.state.settingsMu.RUnlock()
return settings.Provider.PortForwarding.Filepath return settings.Provider.PortForwarding.Filepath
} }
providerConf.PortForward(ctx, providerConf.PortForward(ctx,
client, l.fileManager, l.pfLogger, client, l.fileManager, l.pfLogger,
gateway, l.fw, syncState) gateway, l.fw, syncState)
} }
func (l *looper) GetPortForwarded() (portForwarded uint16) {
l.portForwardedMutex.RLock()
defer l.portForwardedMutex.RUnlock()
return l.portForwarded
}

121
internal/openvpn/state.go Normal file
View File

@@ -0,0 +1,121 @@
package openvpn
import (
"fmt"
"reflect"
"sync"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
)
type state struct {
status models.LoopStatus
settings settings.OpenVPN
allServers models.AllServers
portForwarded uint16
statusMu sync.RWMutex
settingsMu sync.RWMutex
allServersMu sync.RWMutex
portForwardedMu sync.RWMutex
}
func (s *state) setStatusWithLock(status models.LoopStatus) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.status = status
}
func (s *state) getSettingsAndServers() (settings settings.OpenVPN, allServers models.AllServers) {
s.settingsMu.RLock()
s.allServersMu.RLock()
settings = s.settings
allServers = s.allServers
s.settingsMu.RLock()
s.allServersMu.RLock()
return settings, allServers
}
func (l *looper) GetStatus() (status models.LoopStatus) {
l.state.statusMu.RLock()
defer l.state.statusMu.RUnlock()
return l.state.status
}
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
l.state.statusMu.Lock()
defer l.state.statusMu.Unlock()
existingStatus := l.state.status
switch status {
case constants.Running:
switch existingStatus {
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Starting
l.state.statusMu.Unlock()
l.start <- struct{}{}
newStatus := <-l.running
l.state.statusMu.Lock()
l.state.status = newStatus
return newStatus.String(), nil
case constants.Stopped:
switch existingStatus {
case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Stopping
l.state.statusMu.Unlock()
l.stop <- struct{}{}
<-l.stopped
l.state.statusMu.Lock()
l.state.status = constants.Stopped
return status.String(), nil
default:
return "", fmt.Errorf("status %q can only be %q or %q",
status, constants.Running, constants.Stopped)
}
}
func (l *looper) GetSettings() (settings settings.OpenVPN) {
l.state.settingsMu.RLock()
defer l.state.settingsMu.RUnlock()
return l.state.settings
}
func (l *looper) SetSettings(settings settings.OpenVPN) (outcome string) {
l.state.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(l.state.settings, settings)
if settingsUnchanged {
l.state.settingsMu.Unlock()
return "settings left unchanged"
}
l.state.settings = settings
_, _ = l.SetStatus(constants.Stopped)
outcome, _ = l.SetStatus(constants.Running)
return outcome
}
func (l *looper) GetServers() (servers models.AllServers) {
l.state.allServersMu.RLock()
defer l.state.allServersMu.RUnlock()
return l.state.allServers
}
func (l *looper) SetServers(servers models.AllServers) {
l.state.allServersMu.Lock()
defer l.state.allServersMu.Unlock()
l.state.allServers = servers
}
func (l *looper) GetPortForwarded() (port uint16) {
l.state.portForwardedMu.RLock()
defer l.state.portForwardedMu.RUnlock()
return l.state.portForwarded
}

View File

@@ -130,8 +130,8 @@ func (r *reader) GetDNSOverTLSPrivateAddresses() (privateAddresses []string, err
return privateAddresses, nil return privateAddresses, nil
} }
// GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using ipv6 DNS over TLS // GetDNSOverTLSIPv6 obtains if Unbound should resolve ipv6 addresses using
// servers from the environment variable DOT_IPV6. // ipv6 DNS over TLS from the environment variable DOT_IPV6.
func (r *reader) GetDNSOverTLSIPv6() (ipv6 bool, err error) { func (r *reader) GetDNSOverTLSIPv6() (ipv6 bool, err error) {
return r.envParams.GetOnOff("DOT_IPV6", libparams.Default("off")) return r.envParams.GetOnOff("DOT_IPV6", libparams.Default("off"))
} }

View File

@@ -37,7 +37,7 @@ type Reader interface {
GetUID() (uid int, err error) GetUID() (uid int, err error)
GetGID() (gid int, err error) GetGID() (gid int, err error)
GetTimezone() (timezone string, err error) GetTimezone() (timezone string, err error)
GetIPStatusFilepath() (filepath models.Filepath, err error) GetPublicIPFilepath() (filepath models.Filepath, err error)
// Firewall getters // Firewall getters
GetFirewall() (enabled bool, err error) GetFirewall() (enabled bool, err error)

View File

@@ -3,6 +3,7 @@ package params
import ( import (
"time" "time"
"github.com/qdm12/gluetun/internal/models"
libparams "github.com/qdm12/golibs/params" libparams "github.com/qdm12/golibs/params"
) )
@@ -15,3 +16,13 @@ func (r *reader) GetPublicIPPeriod() (period time.Duration, err error) {
} }
return time.ParseDuration(s) return time.ParseDuration(s)
} }
// GetPublicIPFilepath obtains the public IP filepath
// from the environment variable PUBLICIP_FILE with retro-compatible
// environment variable IP_STATUS_FILE.
func (r *reader) GetPublicIPFilepath() (filepath models.Filepath, err error) {
filepathStr, err := r.envParams.GetPath("PUBLICIP_FILE",
libparams.RetroKeys([]string{"IP_STATUS_FILE"}, r.onRetroActive),
libparams.Default("/tmp/gluetun/ip"), libparams.CaseSensitiveValue())
return models.Filepath(filepathStr), err
}

View File

@@ -1,7 +1,6 @@
package params package params
import ( import (
"github.com/qdm12/gluetun/internal/models"
libparams "github.com/qdm12/golibs/params" libparams "github.com/qdm12/golibs/params"
) )
@@ -19,11 +18,3 @@ func (r *reader) GetGID() (gid int, err error) {
func (r *reader) GetTimezone() (timezone string, err error) { func (r *reader) GetTimezone() (timezone string, err error) {
return r.envParams.GetEnv("TZ") return r.envParams.GetEnv("TZ")
} }
// GetIPStatusFilepath obtains the IP status file path
// from the environment variable IP_STATUS_FILE.
func (r *reader) GetIPStatusFilepath() (filepath models.Filepath, err error) {
filepathStr, err := r.envParams.GetPath("IP_STATUS_FILE",
libparams.Default("/tmp/gluetun/ip"), libparams.CaseSensitiveValue())
return models.Filepath(filepathStr), err
}

View File

@@ -62,8 +62,8 @@ func (c *cyberghost) GetOpenVPNConnection(selection models.ServerSelection) (
return pickRandomConnection(connections, c.randSource), nil return pickRandomConnection(connections, c.randSource), nil
} }
func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity, func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity int,
uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { username string, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
} }
@@ -105,7 +105,7 @@ func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity,
lines = append(lines, "ncp-ciphers AES-256-GCM:AES-256-CBC:AES-128-GCM") lines = append(lines, "ncp-ciphers AES-256-GCM:AES-256-CBC:AES-128-GCM")
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -73,7 +73,7 @@ func (m *mullvad) GetOpenVPNConnection(selection models.ServerSelection) (
} }
func (m *mullvad) BuildConf(connection models.OpenVPNConnection, func (m *mullvad) BuildConf(connection models.OpenVPNConnection,
verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { verbosity int, username string, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
} }
@@ -114,7 +114,7 @@ func (m *mullvad) BuildConf(connection models.OpenVPNConnection,
lines = append(lines, `pull-filter ignore "ifconfig-ipv6"`) lines = append(lines, `pull-filter ignore "ifconfig-ipv6"`)
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -78,7 +78,7 @@ func (n *nordvpn) GetOpenVPNConnection(selection models.ServerSelection) (
return pickRandomConnection(connections, n.randSource), nil return pickRandomConnection(connections, n.randSource), nil
} }
func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity int, username string, root bool,
cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
@@ -121,7 +121,7 @@ func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid,
fmt.Sprintf("auth %s", auth), fmt.Sprintf("auth %s", auth),
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -109,7 +109,7 @@ func (p *pia) GetOpenVPNConnection(selection models.ServerSelection) (
return connection, nil return connection, nil
} }
func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity int, username string, root bool,
cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
var X509CRL, certificate string var X509CRL, certificate string
var defaultCipher, defaultAuth string var defaultCipher, defaultAuth string
@@ -161,7 +161,7 @@ func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid
lines = append(lines, "ncp-disable") lines = append(lines, "ncp-disable")
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<crl-verify>", "<crl-verify>",

View File

@@ -70,7 +70,7 @@ func (s *privado) GetOpenVPNConnection(selection models.ServerSelection) (
return pickRandomConnection(connections, s.randSource), nil return pickRandomConnection(connections, s.randSource), nil
} }
func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity int, username string, root bool,
cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
@@ -104,7 +104,7 @@ func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity, uid,
fmt.Sprintf("auth %s", auth), fmt.Sprintf("auth %s", auth),
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -15,7 +15,7 @@ import (
// Provider contains methods to read and modify the openvpn configuration to connect as a client. // Provider contains methods to read and modify the openvpn configuration to connect as a client.
type Provider interface { type Provider interface {
GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error) GetOpenVPNConnection(selection models.ServerSelection) (connection models.OpenVPNConnection, err error)
BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, BuildConf(connection models.OpenVPNConnection, verbosity int, username string,
root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
PortForward(ctx context.Context, client *http.Client, PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator, fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,

View File

@@ -72,7 +72,7 @@ func (p *purevpn) GetOpenVPNConnection(selection models.ServerSelection) (
return pickRandomConnection(connections, p.randSource), nil return pickRandomConnection(connections, p.randSource), nil
} }
func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity int, username string, root bool,
cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
@@ -108,7 +108,7 @@ func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid,
fmt.Sprintf("cipher %s", cipher), fmt.Sprintf("cipher %s", cipher),
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -73,7 +73,7 @@ func (s *surfshark) GetOpenVPNConnection(selection models.ServerSelection) (
return pickRandomConnection(connections, s.randSource), nil return pickRandomConnection(connections, s.randSource), nil
} }
func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, root bool, func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity int, username string, root bool,
cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
@@ -104,6 +104,7 @@ func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity, ui
"auth-nocache", "auth-nocache",
"mute-replay-warnings", "mute-replay-warnings",
"pull-filter ignore \"auth-token\"", // prevent auth failed loops "pull-filter ignore \"auth-token\"", // prevent auth failed loops
"pull-filter ignore \"block-outside-dns\"",
"auth-retry nointeract", "auth-retry nointeract",
"suppress-timestamps", "suppress-timestamps",
@@ -116,7 +117,7 @@ func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity, ui
fmt.Sprintf("auth %s", auth), fmt.Sprintf("auth %s", auth),
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -69,7 +69,7 @@ func (v *vyprvpn) GetOpenVPNConnection(selection models.ServerSelection) (
return pickRandomConnection(connections, v.randSource), nil return pickRandomConnection(connections, v.randSource), nil
} }
func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity int, username string,
root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
@@ -106,7 +106,7 @@ func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity, uid,
fmt.Sprintf("auth %s", auth), fmt.Sprintf("auth %s", auth),
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -72,7 +72,7 @@ func (w *windscribe) GetOpenVPNConnection(selection models.ServerSelection) (con
return pickRandomConnection(connections, w.randSource), nil return pickRandomConnection(connections, w.randSource), nil
} }
func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity, uid, gid int, func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity int, username string,
root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
if len(cipher) == 0 { if len(cipher) == 0 {
cipher = aes256cbc cipher = aes256cbc
@@ -111,7 +111,7 @@ func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity, u
lines = append(lines, "ncp-ciphers AES-256-GCM:AES-256-CBC:AES-128-GCM") lines = append(lines, "ncp-ciphers AES-256-GCM:AES-256-CBC:AES-128-GCM")
} }
if !root { if !root {
lines = append(lines, "user nonrootuser") lines = append(lines, "user "+username)
} }
lines = append(lines, []string{ lines = append(lines, []string{
"<ca>", "<ca>",

View File

@@ -2,10 +2,13 @@ package publicip
import ( import (
"context" "context"
"net"
"sync" "sync"
"time" "time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
@@ -14,62 +17,57 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() GetStatus() (status models.LoopStatus)
Stop() SetStatus(status models.LoopStatus) (outcome string, err error)
GetPeriod() (period time.Duration) GetSettings() (settings settings.PublicIP)
SetPeriod(period time.Duration) SetSettings(settings settings.PublicIP) (outcome string)
GetPublicIP() (publicIP net.IP)
} }
type looper struct { type looper struct {
period time.Duration state state
periodMutex sync.RWMutex // Objects
getter IPGetter getter IPGetter
logger logging.Logger logger logging.Logger
fileManager files.FileManager fileManager files.FileManager
ipStatusFilepath models.Filepath // Fixed settings
uid int uid int
gid int gid int
restart chan struct{} // Internal channels and locks
stop chan struct{} loopLock sync.Mutex
updateTicker chan struct{} start chan struct{}
timeNow func() time.Time running chan models.LoopStatus
timeSince func(time.Time) time.Duration stop chan struct{}
stopped chan struct{}
updateTicker chan struct{}
// Mock functions
timeNow func() time.Time
timeSince func(time.Time) time.Duration
} }
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
ipStatusFilepath models.Filepath, period time.Duration, uid, gid int) Looper { settings settings.PublicIP, uid, gid int) Looper {
return &looper{ return &looper{
period: period, state: state{
getter: NewIPGetter(client), status: constants.Stopped,
logger: logger.WithPrefix("ip getter: "), settings: settings,
fileManager: fileManager, },
ipStatusFilepath: ipStatusFilepath, // Objects
uid: uid, getter: NewIPGetter(client),
gid: gid, logger: logger.WithPrefix("ip getter: "),
restart: make(chan struct{}), fileManager: fileManager,
stop: make(chan struct{}), uid: uid,
updateTicker: make(chan struct{}), gid: gid,
timeNow: time.Now, start: make(chan struct{}),
timeSince: time.Since, running: make(chan models.LoopStatus),
stop: make(chan struct{}),
stopped: make(chan struct{}),
updateTicker: make(chan struct{}),
timeNow: time.Now,
timeSince: time.Since,
} }
} }
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Stop() { l.stop <- struct{}{} }
func (l *looper) GetPeriod() (period time.Duration) {
l.periodMutex.RLock()
defer l.periodMutex.RUnlock()
return l.period
}
func (l *looper) SetPeriod(period time.Duration) {
l.periodMutex.Lock()
l.period = period
l.periodMutex.Unlock()
l.updateTicker <- struct{}{}
}
func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err) l.logger.Error(err)
const waitTime = 5 * time.Second const waitTime = 5 * time.Second
@@ -86,53 +84,84 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
crashed := false
select { select {
case <-l.restart: case <-l.start:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited") defer l.logger.Warn("loop exited")
enabled := true
for ctx.Err() == nil { for ctx.Err() == nil {
for !enabled { getCtx, getCancel := context.WithCancel(ctx)
// wait for a signal to re-enable defer getCancel()
select {
case <-l.stop: ipCh := make(chan net.IP)
l.logger.Info("already disabled") errorCh := make(chan error)
case <-l.restart: go func() {
enabled = true ip, err := l.getter.Get(getCtx)
case <-ctx.Done(): if err != nil {
errorCh <- err
return return
} }
ipCh <- ip
}()
if !crashed {
l.running <- constants.Running
crashed = false
} else {
l.state.setStatusWithLock(constants.Running)
} }
// Enabled and has a period set stayHere := true
for stayHere {
ip, err := l.getter.Get(ctx) select {
if err != nil { case <-ctx.Done():
l.logAndWait(ctx, err) l.logger.Warn("context canceled: exiting loop")
continue getCancel()
} close(errorCh)
l.logger.Info("Public IP address is %s", ip) filepath := l.GetSettings().IPFilepath
const userReadWritePermissions = 0600 l.logger.Info("Removing ip file %s", filepath)
err = l.fileManager.WriteLinesToFile( if err := l.fileManager.Remove(string(filepath)); err != nil {
string(l.ipStatusFilepath), l.logger.Error(err)
[]string{ip.String()}, }
files.Ownership(l.uid, l.gid), return
files.Permissions(userReadWritePermissions)) case <-l.start:
if err != nil { l.logger.Info("starting")
l.logAndWait(ctx, err) getCancel()
continue stayHere = false
} case <-l.stop:
select { l.logger.Info("stopping")
case <-l.restart: // triggered restart getCancel()
case <-l.stop: <-errorCh
enabled = false l.stopped <- struct{}{}
case <-ctx.Done(): case ip := <-ipCh:
return getCancel()
l.state.setPublicIP(ip)
l.logger.Info("Public IP address is %s", ip)
const userReadWritePermissions = 0600
err := l.fileManager.WriteLinesToFile(
string(l.state.settings.IPFilepath),
[]string{ip.String()},
files.Ownership(l.uid, l.gid),
files.Permissions(userReadWritePermissions))
if err != nil {
l.logger.Error(err)
}
l.state.setStatusWithLock(constants.Completed)
case err := <-errorCh:
getCancel()
close(ipCh)
l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err)
crashed = true
stayHere = false
}
} }
close(errorCh)
} }
} }
@@ -141,10 +170,9 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
timer := time.NewTimer(time.Hour) timer := time.NewTimer(time.Hour)
timer.Stop() // 1 hour, cannot be a race condition timer.Stop() // 1 hour, cannot be a race condition
timerIsStopped := true timerIsStopped := true
period := l.GetPeriod() if period := l.GetSettings().Period; period > 0 {
if period > 0 {
timer.Reset(period)
timerIsStopped = false timerIsStopped = false
timer.Reset(period)
} }
lastTick := time.Unix(0, 0) lastTick := time.Unix(0, 0)
for { for {
@@ -156,14 +184,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
return return
case <-timer.C: case <-timer.C:
lastTick = l.timeNow() lastTick = l.timeNow()
l.restart <- struct{}{} l.start <- struct{}{}
timer.Reset(l.GetPeriod()) timer.Reset(l.GetSettings().Period)
case <-l.updateTicker: case <-l.updateTicker:
if !timer.Stop() { if !timerIsStopped && !timer.Stop() {
<-timer.C <-timer.C
} }
timerIsStopped = true timerIsStopped = true
period := l.GetPeriod() period := l.GetSettings().Period
if period == 0 { if period == 0 {
continue continue
} }

110
internal/publicip/state.go Normal file
View File

@@ -0,0 +1,110 @@
package publicip
import (
"fmt"
"net"
"reflect"
"sync"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
)
type state struct {
status models.LoopStatus
settings settings.PublicIP
ip net.IP
statusMu sync.RWMutex
settingsMu sync.RWMutex
ipMu sync.RWMutex
}
func (s *state) setStatusWithLock(status models.LoopStatus) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.status = status
}
func (l *looper) GetStatus() (status models.LoopStatus) {
l.state.statusMu.RLock()
defer l.state.statusMu.RUnlock()
return l.state.status
}
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
l.state.statusMu.Lock()
defer l.state.statusMu.Unlock()
existingStatus := l.state.status
switch status {
case constants.Running:
switch existingStatus {
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Starting
l.state.statusMu.Unlock()
l.start <- struct{}{}
newStatus := <-l.running
l.state.statusMu.Lock()
l.state.status = newStatus
return newStatus.String(), nil
case constants.Stopped:
switch existingStatus {
case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Stopping
l.state.statusMu.Unlock()
l.stop <- struct{}{}
<-l.stopped
l.state.statusMu.Lock()
l.state.status = status
return status.String(), nil
default:
return "", fmt.Errorf("status %q can only be %q or %q",
status, constants.Running, constants.Stopped)
}
}
func (l *looper) GetSettings() (settings settings.PublicIP) {
l.state.settingsMu.RLock()
defer l.state.settingsMu.RUnlock()
return l.state.settings
}
func (l *looper) SetSettings(settings settings.PublicIP) (outcome string) {
l.state.settingsMu.Lock()
defer l.state.settingsMu.Unlock()
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
if settingsUnchanged {
return "settings left unchanged"
}
periodChanged := l.state.settings.Period != settings.Period
l.state.settings = settings
if periodChanged {
l.updateTicker <- struct{}{}
// TODO blocking
}
return "settings updated"
}
func (l *looper) GetPublicIP() (publicIP net.IP) {
l.state.ipMu.RLock()
defer l.state.ipMu.RUnlock()
publicIP = make(net.IP, len(l.state.ip))
copy(publicIP, l.state.ip)
return publicIP
}
func (s *state) setPublicIP(publicIP net.IP) {
s.ipMu.Lock()
defer s.ipMu.Unlock()
s.ip = make(net.IP, len(publicIP))
copy(s.ip, publicIP)
}

76
internal/server/dns.go Normal file
View File

@@ -0,0 +1,76 @@
//nolint:dupl
package server
import (
"encoding/json"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/golibs/logging"
)
func newDNSHandler(looper dns.Looper, logger logging.Logger) http.Handler {
return &dnsHandler{
looper: looper,
logger: logger,
}
}
type dnsHandler struct {
looper dns.Looper
logger logging.Logger
}
func (h *dnsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/dns")
switch r.RequestURI {
case "/status": //nolint:goconst
switch r.Method {
case http.MethodGet:
h.getStatus(w)
case http.MethodPut:
h.setStatus(w, r)
default:
http.Error(w, "", http.StatusNotFound)
}
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *dnsHandler) getStatus(w http.ResponseWriter) {
status := h.looper.GetStatus()
encoder := json.NewEncoder(w)
data := statusWrapper{Status: string(status)}
if err := encoder.Encode(data); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func (h *dnsHandler) setStatus(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
var data statusWrapper
if err := decoder.Decode(&data); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
status, err := data.getStatus()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
outcome, err := h.looper.SetStatus(status)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
encoder := json.NewEncoder(w)
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
h.logger.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}

View File

@@ -1,12 +1,13 @@
package server package server
import ( import (
"fmt"
"net/http" "net/http"
"strings"
"github.com/qdm12/gluetun/internal/dns" "github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -16,55 +17,36 @@ func newHandler(logger logging.Logger, logging bool,
openvpnLooper openvpn.Looper, openvpnLooper openvpn.Looper,
unboundLooper dns.Looper, unboundLooper dns.Looper,
updaterLooper updater.Looper, updaterLooper updater.Looper,
publicIPLooper publicip.Looper,
) http.Handler { ) http.Handler {
return &handler{ handler := &handler{}
logger: logger,
logging: logging, openvpn := newOpenvpnHandler(openvpnLooper, logger)
buildInfo: buildInfo, dns := newDNSHandler(unboundLooper, logger)
openvpnLooper: openvpnLooper, updater := newUpdaterHandler(updaterLooper, logger)
unboundLooper: unboundLooper, publicip := newPublicIPHandler(publicIPLooper, logger)
updaterLooper: updaterLooper,
} handler.v0 = newHandlerV0(logger, openvpnLooper, unboundLooper, updaterLooper)
handler.v1 = newHandlerV1(logger, buildInfo, openvpn, dns, updater, publicip)
handlerWithLog := withLogMiddleware(handler, logger, logging)
handler.setLogEnabled = handlerWithLog.setEnabled
return handlerWithLog
} }
type handler struct { type handler struct {
logger logging.Logger v0 http.Handler
logging bool v1 http.Handler
buildInfo models.BuildInformation setLogEnabled func(enabled bool)
openvpnLooper openvpn.Looper
unboundLooper dns.Looper
updaterLooper updater.Looper
} }
func (h *handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.logging { r.RequestURI = strings.TrimSuffix(r.RequestURI, "/")
h.logger.Info("HTTP %s %s", request.Method, request.RequestURI) if !strings.HasPrefix(r.RequestURI, "/v1/") && r.RequestURI != "/v1" {
} h.v0.ServeHTTP(w, r)
switch request.Method { return
case http.MethodGet:
switch request.RequestURI {
case "/version":
h.getVersion(responseWriter)
responseWriter.WriteHeader(http.StatusOK)
case "/openvpn/actions/restart":
h.openvpnLooper.Restart()
responseWriter.WriteHeader(http.StatusOK)
case "/unbound/actions/restart":
h.unboundLooper.Restart()
responseWriter.WriteHeader(http.StatusOK)
case "/openvpn/portforwarded":
h.getPortForwarded(responseWriter)
case "/openvpn/settings":
h.getOpenvpnSettings(responseWriter)
case "/updater/restart":
h.updaterLooper.Restart()
responseWriter.WriteHeader(http.StatusOK)
default:
errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI)
http.Error(responseWriter, errString, http.StatusBadRequest)
}
default:
errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI)
http.Error(responseWriter, errString, http.StatusBadRequest)
} }
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/v1")
h.v1.ServeHTTP(w, r)
} }

View File

@@ -0,0 +1,69 @@
package server
import (
"net/http"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging"
)
func newHandlerV0(logger logging.Logger,
openvpn openvpn.Looper, dns dns.Looper, updater updater.Looper) http.Handler {
return &handlerV0{
logger: logger,
openvpn: openvpn,
dns: dns,
updater: updater,
}
}
type handlerV0 struct {
logger logging.Logger
openvpn openvpn.Looper
dns dns.Looper
updater updater.Looper
}
func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "unversioned API: only supports GET method", http.StatusNotFound)
return
}
switch r.RequestURI {
case "/version":
http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect)
case "/openvpn/actions/restart":
outcome, _ := h.openvpn.SetStatus(constants.Stopped)
h.logger.Info("openvpn: %s", outcome)
outcome, _ = h.openvpn.SetStatus(constants.Running)
h.logger.Info("openvpn: %s", outcome)
if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil {
h.logger.Warn(err)
}
case "/unbound/actions/restart":
outcome, _ := h.dns.SetStatus(constants.Stopped)
h.logger.Info("dns: %s", outcome)
outcome, _ = h.dns.SetStatus(constants.Running)
h.logger.Info("dns: %s", outcome)
if _, err := w.Write([]byte("dns restarted, please consider using the /v1/ API in the future.")); err != nil {
h.logger.Warn(err)
}
case "/openvpn/portforwarded":
http.Redirect(w, r, "/v1/openvpn/portforwarded", http.StatusPermanentRedirect)
case "/openvpn/settings":
http.Redirect(w, r, "/v1/openvpn/settings", http.StatusPermanentRedirect)
case "/updater/restart":
outcome, _ := h.updater.SetStatus(constants.Stopped)
h.logger.Info("updater: %s", outcome)
outcome, _ = h.updater.SetStatus(constants.Running)
h.logger.Info("updater: %s", outcome)
if _, err := w.Write([]byte("updater restarted, please consider using the /v1/ API in the future.")); err != nil {
h.logger.Warn(err)
}
default:
http.Error(w, "unversioned API: requested URI not found", http.StatusNotFound)
}
}

View File

@@ -0,0 +1,62 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/logging"
)
func newHandlerV1(logger logging.Logger, buildInfo models.BuildInformation,
openvpn, dns, updater, publicip http.Handler) http.Handler {
return &handlerV1{
logger: logger,
buildInfo: buildInfo,
openvpn: openvpn,
dns: dns,
updater: updater,
publicip: publicip,
}
}
type handlerV1 struct {
logger logging.Logger
buildInfo models.BuildInformation
openvpn http.Handler
dns http.Handler
updater http.Handler
publicip http.Handler
}
func (h *handlerV1) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
case r.RequestURI == "/version" && r.Method == http.MethodGet:
h.getVersion(w)
case strings.HasPrefix(r.RequestURI, "/openvpn"):
h.openvpn.ServeHTTP(w, r)
case strings.HasPrefix(r.RequestURI, "/dns"):
h.dns.ServeHTTP(w, r)
case strings.HasPrefix(r.RequestURI, "/updater"):
h.updater.ServeHTTP(w, r)
case strings.HasPrefix(r.RequestURI, "/publicip"):
h.publicip.ServeHTTP(w, r)
default:
errString := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI)
http.Error(w, errString, http.StatusNotFound)
}
}
func (h *handlerV1) getVersion(w http.ResponseWriter) {
data, err := json.Marshal(h.buildInfo)
if err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if _, err := w.Write(data); err != nil {
h.logger.Warn(err)
}
}

75
internal/server/log.go Normal file
View File

@@ -0,0 +1,75 @@
package server
import (
"net/http"
"sync"
"time"
"github.com/qdm12/golibs/logging"
)
func withLogMiddleware(childHandler http.Handler, logger logging.Logger, enabled bool) *logMiddleware {
return &logMiddleware{
childHandler: childHandler,
logger: logger,
timeNow: time.Now,
enabled: enabled,
}
}
type logMiddleware struct {
childHandler http.Handler
logger logging.Logger
timeNow func() time.Time
enabled bool
enabledMu sync.RWMutex
}
func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !m.isEnabled() {
m.childHandler.ServeHTTP(w, r)
return
}
tStart := m.timeNow()
statefulWriter := &statefulResponseWriter{httpWriter: w}
m.childHandler.ServeHTTP(statefulWriter, r)
duration := m.timeNow().Sub(tStart)
m.logger.Info("%d %s %s wrote %dB to %s in %s",
statefulWriter.statusCode, r.Method, r.RequestURI, statefulWriter.length, r.RemoteAddr, duration)
}
func (m *logMiddleware) setEnabled(enabled bool) {
m.enabledMu.Lock()
defer m.enabledMu.Unlock()
m.enabled = enabled
}
func (m *logMiddleware) isEnabled() (enabled bool) {
m.enabledMu.RLock()
defer m.enabledMu.RUnlock()
return m.enabled
}
type statefulResponseWriter struct {
httpWriter http.ResponseWriter
statusCode int
length int
}
func (w *statefulResponseWriter) Write(b []byte) (n int, err error) {
n, err = w.httpWriter.Write(b)
if w.statusCode == 0 {
w.statusCode = http.StatusOK
}
w.length += n
return n, err
}
func (w *statefulResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.httpWriter.WriteHeader(statusCode)
}
func (w *statefulResponseWriter) Header() http.Header {
return w.httpWriter.Header()
}

View File

@@ -3,34 +3,110 @@ package server
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/golibs/logging"
) )
func (h *handler) getPortForwarded(w http.ResponseWriter) { func newOpenvpnHandler(looper openvpn.Looper, logger logging.Logger) http.Handler {
port := h.openvpnLooper.GetPortForwarded() return &openvpnHandler{
data, err := json.Marshal(struct { looper: looper,
Port uint16 `json:"port"` logger: logger,
}{port})
if err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if _, err := w.Write(data); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
} }
} }
func (h *handler) getOpenvpnSettings(w http.ResponseWriter) { type openvpnHandler struct {
settings := h.openvpnLooper.GetSettings() looper openvpn.Looper
data, err := json.Marshal(settings) logger logging.Logger
if err != nil { }
func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/openvpn")
switch r.RequestURI {
case "/status":
switch r.Method {
case http.MethodGet:
h.getStatus(w)
case http.MethodPut:
h.setStatus(w, r)
default:
http.Error(w, "", http.StatusNotFound)
}
case "/settings":
switch r.Method {
case http.MethodGet:
h.getSettings(w)
default:
http.Error(w, "", http.StatusNotFound)
}
case "/portforwarded":
switch r.Method {
case http.MethodGet:
h.getPortForwarded(w)
default:
http.Error(w, "", http.StatusNotFound)
}
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *openvpnHandler) getStatus(w http.ResponseWriter) {
status := h.looper.GetStatus()
encoder := json.NewEncoder(w)
data := statusWrapper{Status: string(status)}
if err := encoder.Encode(data); err != nil {
h.logger.Warn(err) h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
if _, err := w.Write(data); err != nil { }
func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) { //nolint:dupl
decoder := json.NewDecoder(r.Body)
var data statusWrapper
if err := decoder.Decode(&data); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
status, err := data.getStatus()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
outcome, err := h.looper.SetStatus(status)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
encoder := json.NewEncoder(w)
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
h.logger.Warn(err) h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
settings := h.looper.GetSettings()
settings.User = "redacted"
settings.Password = "redacted"
encoder := json.NewEncoder(w)
if err := encoder.Encode(settings); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
port := h.looper.GetPortForwarded()
encoder := json.NewEncoder(w)
data := portWrapper{Port: port}
if err := encoder.Encode(data); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
} }
} }

View File

@@ -0,0 +1,55 @@
//nolint:dupl
package server
import (
"encoding/json"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/golibs/logging"
)
func newPublicIPHandler(
looper publicip.Looper,
logger logging.Logger) http.Handler {
return &publicIPHandler{
looper: looper,
logger: logger,
}
}
type publicIPHandler struct {
looper publicip.Looper
logger logging.Logger
}
func (h *publicIPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/publicip")
switch r.RequestURI {
case "/ip":
switch r.Method {
case http.MethodGet:
h.getPublicIP(w)
default:
http.Error(w, "", http.StatusNotFound)
}
default:
http.Error(w, "", http.StatusNotFound)
}
}
type publicIPWrapper struct {
PublicIP string `json:"public_ip"`
}
func (h *publicIPHandler) getPublicIP(w http.ResponseWriter) {
publicIP := h.looper.GetPublicIP()
encoder := json.NewEncoder(w)
data := publicIPWrapper{PublicIP: publicIP.String()}
if err := encoder.Encode(data); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/qdm12/gluetun/internal/dns" "github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -23,10 +24,13 @@ type server struct {
handler http.Handler handler http.Handler
} }
func New(address string, logging bool, logger logging.Logger, buildInfo models.BuildInformation, func New(address string, logging bool, logger logging.Logger,
openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server { buildInfo models.BuildInformation,
openvpnLooper openvpn.Looper, unboundLooper dns.Looper,
updaterLooper updater.Looper, publicIPLooper publicip.Looper) Server {
serverLogger := logger.WithPrefix("http server: ") serverLogger := logger.WithPrefix("http server: ")
handler := newHandler(serverLogger, logging, buildInfo, openvpnLooper, unboundLooper, updaterLooper) handler := newHandler(serverLogger, logging, buildInfo,
openvpnLooper, unboundLooper, updaterLooper, publicIPLooper)
return &server{ return &server{
address: address, address: address,
logger: serverLogger, logger: serverLogger,

View File

@@ -0,0 +1,78 @@
//nolint:dupl
package server
import (
"encoding/json"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging"
)
func newUpdaterHandler(
looper updater.Looper,
logger logging.Logger) http.Handler {
return &updaterHandler{
looper: looper,
logger: logger,
}
}
type updaterHandler struct {
looper updater.Looper
logger logging.Logger
}
func (h *updaterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.RequestURI = strings.TrimPrefix(r.RequestURI, "/updater")
switch r.RequestURI {
case "/status":
switch r.Method {
case http.MethodGet:
h.getStatus(w)
case http.MethodPut:
h.setStatus(w, r)
default:
http.Error(w, "", http.StatusNotFound)
}
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *updaterHandler) getStatus(w http.ResponseWriter) {
status := h.looper.GetStatus()
encoder := json.NewEncoder(w)
data := statusWrapper{Status: string(status)}
if err := encoder.Encode(data); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func (h *updaterHandler) setStatus(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
var data statusWrapper
if err := decoder.Decode(&data); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
status, err := data.getStatus()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
outcome, err := h.looper.SetStatus(status)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
encoder := json.NewEncoder(w)
if err := encoder.Encode(outcomeWrapper{Outcome: outcome}); err != nil {
h.logger.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}

View File

@@ -1,19 +0,0 @@
package server
import (
"encoding/json"
"net/http"
)
func (h *handler) getVersion(w http.ResponseWriter) {
data, err := json.Marshal(h.buildInfo)
if err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if _, err := w.Write(data); err != nil {
h.logger.Warn(err)
w.WriteHeader(http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,32 @@
package server
import (
"fmt"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
type statusWrapper struct {
Status string `json:"status"`
}
func (sw *statusWrapper) getStatus() (status models.LoopStatus, err error) {
status = models.LoopStatus(sw.Status)
switch status {
case constants.Stopped, constants.Running:
return status, nil
default:
return "", fmt.Errorf(
"invalid status %q: possible values are: %s, %s",
sw.Status, constants.Stopped, constants.Running)
}
}
type portWrapper struct {
Port uint16 `json:"port"`
}
type outcomeWrapper struct {
Outcome string `json:"outcome"`
}

View File

@@ -12,9 +12,9 @@ import (
// OpenVPN contains settings to configure the OpenVPN client. // OpenVPN contains settings to configure the OpenVPN client.
type OpenVPN struct { type OpenVPN struct {
User string `json:"user"` User string `json:"user"`
Password string `json:"-"` Password string `json:"password"`
Verbosity int `json:"verbosity"` Verbosity int `json:"verbosity"`
Root bool `json:"runAsRoot"` Root bool `json:"run_as_root"`
Cipher string `json:"cipher"` Cipher string `json:"cipher"`
Auth string `json:"auth"` Auth string `json:"auth"`
Provider models.ProviderSettings `json:"provider"` Provider models.ProviderSettings `json:"provider"`

View File

@@ -20,7 +20,7 @@ func Test_OpenVPN_JSON(t *testing.T) {
data, err := json.Marshal(in) data, err := json.Marshal(in)
require.NoError(t, err) require.NoError(t, err)
//nolint:lll //nolint:lll
assert.Equal(t, `{"user":"","verbosity":0,"runAsRoot":true,"cipher":"","auth":"","provider":{"name":"name","serverSelection":{"networkProtocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"customPort":0,"numbers":null,"encryptionPreset":""},"extraConfig":{"encryptionPreset":"","openvpnIPv6":false},"portForwarding":{"enabled":false,"filepath":""}}}`, string(data)) assert.Equal(t, `{"user":"","password":"","verbosity":0,"run_as_root":true,"cipher":"","auth":"","provider":{"name":"name","server_selection":{"network_protocol":"","regions":null,"group":"","countries":null,"cities":null,"hostnames":null,"isps":null,"owned":false,"custom_port":0,"numbers":null,"encryption_preset":""},"extra_config":{"encryption_preset":"","openvpn_ipv6":false},"port_forwarding":{"enabled":false,"filepath":""}}}`, string(data))
var out OpenVPN var out OpenVPN
err = json.Unmarshal(data, &out) err = json.Unmarshal(data, &out)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -213,7 +213,7 @@ func GetNordvpnSettings(paramsReader params.Reader) (settings models.ProviderSet
// GetPurevpnSettings obtains Purevpn settings from environment variables using the params package. // GetPurevpnSettings obtains Purevpn settings from environment variables using the params package.
func GetPurevpnSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) { func GetPurevpnSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
settings.Name = constants.Mullvad settings.Name = constants.Purevpn
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol() settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
if err != nil { if err != nil {
return settings, err return settings, err

View File

@@ -0,0 +1,39 @@
package settings
import (
"fmt"
"strings"
"time"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/params"
)
type PublicIP struct {
Period time.Duration `json:"period"`
IPFilepath models.Filepath `json:"ip_filepath"`
}
func getPublicIPSettings(paramsReader params.Reader) (settings PublicIP, err error) {
settings.Period, err = paramsReader.GetPublicIPPeriod()
if err != nil {
return settings, err
}
settings.IPFilepath, err = paramsReader.GetPublicIPFilepath()
if err != nil {
return settings, err
}
return settings, nil
}
func (s *PublicIP) String() string {
if s.Period == 0 {
return "Public IP getter settings: disabled"
}
settingsList := []string{
"Public IP getter settings:",
fmt.Sprintf("Period: %s", s.Period),
fmt.Sprintf("IP file: %s", s.IPFilepath),
}
return strings.Join(settingsList, "\n|--")
}

View File

@@ -1,9 +1,7 @@
package settings package settings
import ( import (
"fmt"
"strings" "strings"
"time"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/params"
@@ -23,8 +21,8 @@ type Settings struct {
Firewall Firewall Firewall Firewall
HTTPProxy HTTPProxy HTTPProxy HTTPProxy
ShadowSocks ShadowSocks ShadowSocks ShadowSocks
PublicIPPeriod time.Duration Updater Updater
UpdaterPeriod time.Duration PublicIP PublicIP
VersionInformation bool VersionInformation bool
ControlServer ControlServer ControlServer ControlServer
} }
@@ -34,10 +32,6 @@ func (s *Settings) String() string {
if s.VersionInformation { if s.VersionInformation {
versionInformation = enabled versionInformation = enabled
} }
updaterLine := "Updater: disabled"
if s.UpdaterPeriod > 0 {
updaterLine = fmt.Sprintf("Updater period: %s", s.UpdaterPeriod)
}
return strings.Join([]string{ return strings.Join([]string{
"Settings summary below:", "Settings summary below:",
s.OpenVPN.String(), s.OpenVPN.String(),
@@ -47,9 +41,9 @@ func (s *Settings) String() string {
s.HTTPProxy.String(), s.HTTPProxy.String(),
s.ShadowSocks.String(), s.ShadowSocks.String(),
s.ControlServer.String(), s.ControlServer.String(),
"Public IP check period: " + s.PublicIPPeriod.String(), // TODO print disabled if 0 s.Updater.String(),
s.PublicIP.String(),
"Version information: " + versionInformation, "Version information: " + versionInformation,
updaterLine,
"", // new line at the end "", // new line at the end
}, "\n") }, "\n")
} }
@@ -85,7 +79,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) {
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.PublicIPPeriod, err = paramsReader.GetPublicIPPeriod() settings.PublicIP, err = getPublicIPSettings(paramsReader)
if err != nil { if err != nil {
return settings, err return settings, err
} }
@@ -93,7 +87,7 @@ func GetAllSettings(paramsReader params.Reader) (settings Settings, err error) {
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.UpdaterPeriod, err = paramsReader.GetUpdaterPeriod() settings.Updater, err = GetUpdaterSettings(paramsReader)
if err != nil { if err != nil {
return settings, err return settings, err
} }

View File

@@ -4,16 +4,14 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/params"
) )
// System contains settings to configure system related elements. // System contains settings to configure system related elements.
type System struct { type System struct {
UID int UID int
GID int GID int
Timezone string Timezone string
IPStatusFilepath models.Filepath
} }
// GetSystemSettings obtains the System settings using the params functions. // GetSystemSettings obtains the System settings using the params functions.
@@ -30,10 +28,6 @@ func GetSystemSettings(paramsReader params.Reader) (settings System, err error)
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.IPStatusFilepath, err = paramsReader.GetIPStatusFilepath()
if err != nil {
return settings, err
}
return settings, nil return settings, nil
} }
@@ -43,7 +37,6 @@ func (s *System) String() string {
fmt.Sprintf("User ID: %d", s.UID), fmt.Sprintf("User ID: %d", s.UID),
fmt.Sprintf("Group ID: %d", s.GID), fmt.Sprintf("Group ID: %d", s.GID),
fmt.Sprintf("Timezone: %s", s.Timezone), fmt.Sprintf("Timezone: %s", s.Timezone),
fmt.Sprintf("IP Status filepath: %s", s.IPStatusFilepath),
} }
return strings.Join(settingsList, "\n|--") return strings.Join(settingsList, "\n|--")
} }

View File

@@ -0,0 +1,59 @@
package settings
import (
"fmt"
"strings"
"time"
"github.com/qdm12/gluetun/internal/params"
)
type Updater struct {
Period time.Duration `json:"period"`
DNSAddress string `json:"dns_address"`
Cyberghost bool `json:"cyberghost"`
Mullvad bool `json:"mullvad"`
Nordvpn bool `json:"nordvpn"`
PIA bool `json:"pia"`
Privado bool `json:"privado"`
Purevpn bool `json:"purevpn"`
Surfshark bool `json:"surfshark"`
Vyprvpn bool `json:"vyprvpn"`
Windscribe bool `json:"windscribe"`
// The two below should be used in CLI mode only
Stdout bool `json:"-"` // in order to update constants file (maintainer side)
CLI bool `json:"-"`
}
// GetUpdaterSettings obtains the server updater settings using the params functions.
func GetUpdaterSettings(paramsReader params.Reader) (settings Updater, err error) {
settings = Updater{
Cyberghost: true,
Mullvad: true,
Nordvpn: true,
PIA: true,
Purevpn: true,
Surfshark: true,
Vyprvpn: true,
Windscribe: true,
Stdout: false,
CLI: false,
DNSAddress: "127.0.0.1",
}
settings.Period, err = paramsReader.GetUpdaterPeriod()
if err != nil {
return settings, err
}
return settings, nil
}
func (s *Updater) String() string {
if s.Period == 0 {
return "Server updater settings: disabled"
}
settingsList := []string{
"Server updater settings:",
fmt.Sprintf("Period: %s", s.Period),
}
return strings.Join(settingsList, "\n|--")
}

View File

@@ -6,7 +6,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/gluetun/internal/storage" "github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -14,60 +16,54 @@ import (
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup)
Restart() GetStatus() (status models.LoopStatus)
Stop() SetStatus(status models.LoopStatus) (outcome string, err error)
GetPeriod() (period time.Duration) GetSettings() (settings settings.Updater)
SetPeriod(period time.Duration) SetSettings(settings settings.Updater) (outcome string)
} }
type looper struct { type looper struct {
period time.Duration state state
periodMutex sync.RWMutex // Objects
updater Updater updater Updater
storage storage.Storage storage storage.Storage
setAllServers func(allServers models.AllServers) setAllServers func(allServers models.AllServers)
logger logging.Logger logger logging.Logger
restart chan struct{} // Internal channels and locks
stop chan struct{} loopLock sync.Mutex
updateTicker chan struct{} start chan struct{}
timeNow func() time.Time running chan models.LoopStatus
timeSince func(time.Time) time.Duration stop chan struct{}
stopped chan struct{}
updateTicker chan struct{}
// Mock functions
timeNow func() time.Time
timeSince func(time.Time) time.Duration
} }
func NewLooper(options Options, period time.Duration, currentServers models.AllServers, func NewLooper(settings settings.Updater, currentServers models.AllServers,
storage storage.Storage, setAllServers func(allServers models.AllServers), storage storage.Storage, setAllServers func(allServers models.AllServers),
client *http.Client, logger logging.Logger) Looper { client *http.Client, logger logging.Logger) Looper {
loggerWithPrefix := logger.WithPrefix("updater: ") loggerWithPrefix := logger.WithPrefix("updater: ")
return &looper{ return &looper{
period: period, state: state{
updater: New(options, client, currentServers, loggerWithPrefix), status: constants.Stopped,
settings: settings,
},
updater: New(settings, client, currentServers, loggerWithPrefix),
storage: storage, storage: storage,
setAllServers: setAllServers, setAllServers: setAllServers,
logger: loggerWithPrefix, logger: loggerWithPrefix,
restart: make(chan struct{}), start: make(chan struct{}),
running: make(chan models.LoopStatus),
stop: make(chan struct{}), stop: make(chan struct{}),
stopped: make(chan struct{}),
updateTicker: make(chan struct{}), updateTicker: make(chan struct{}),
timeNow: time.Now, timeNow: time.Now,
timeSince: time.Since, timeSince: time.Since,
} }
} }
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Stop() { l.stop <- struct{}{} }
func (l *looper) GetPeriod() (period time.Duration) {
l.periodMutex.RLock()
defer l.periodMutex.RUnlock()
return l.period
}
func (l *looper) SetPeriod(period time.Duration) {
l.periodMutex.Lock()
l.period = period
l.periodMutex.Unlock()
l.updateTicker <- struct{}{}
}
func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err) l.logger.Error(err)
const waitTime = 5 * time.Minute const waitTime = 5 * time.Minute
@@ -84,52 +80,70 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
crashed := false
select { select {
case <-l.restart: case <-l.start:
l.logger.Info("starting...")
case <-ctx.Done(): case <-ctx.Done():
return return
} }
defer l.logger.Warn("loop exited") defer l.logger.Warn("loop exited")
enabled := true
for ctx.Err() == nil { for ctx.Err() == nil {
for !enabled { updateCtx, updateCancel := context.WithCancel(ctx)
// wait for a signal to re-enable defer updateCancel()
serversCh := make(chan models.AllServers)
errorCh := make(chan error)
go func() {
servers, err := l.updater.UpdateServers(updateCtx)
if err != nil {
errorCh <- err
return
}
serversCh <- servers
}()
if !crashed {
l.running <- constants.Running
crashed = false
} else {
l.state.setStatusWithLock(constants.Running)
}
stayHere := true
for stayHere {
select { select {
case <-l.stop:
l.logger.Info("already disabled")
case <-l.restart:
enabled = true
case <-ctx.Done(): case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
updateCancel()
close(errorCh)
return return
case <-l.start:
l.logger.Info("starting")
updateCancel()
stayHere = false
case <-l.stop:
l.logger.Info("stopping")
updateCancel()
<-errorCh
l.stopped <- struct{}{}
case servers := <-serversCh:
updateCancel()
l.setAllServers(servers)
if err := l.storage.FlushToFile(servers); err != nil {
l.logger.Error(err)
}
l.state.setStatusWithLock(constants.Completed)
l.logger.Info("Updated servers information")
case err := <-errorCh:
updateCancel()
close(serversCh)
l.state.setStatusWithLock(constants.Crashed)
l.logAndWait(ctx, err)
crashed = true
stayHere = false
} }
} }
close(errorCh)
// Enabled and has a period set
servers, err := l.updater.UpdateServers(ctx)
if err != nil {
if ctx.Err() != nil {
return
}
l.logAndWait(ctx, err)
continue
}
l.setAllServers(servers)
if err := l.storage.FlushToFile(servers); err != nil {
l.logger.Error(err)
}
l.logger.Info("Updated servers information")
select {
case <-l.restart: // triggered restart
case <-l.stop:
enabled = false
case <-ctx.Done():
return
}
} }
} }
@@ -138,7 +152,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
timer := time.NewTimer(time.Hour) timer := time.NewTimer(time.Hour)
timer.Stop() timer.Stop()
timerIsStopped := true timerIsStopped := true
if period := l.GetPeriod(); period > 0 { if period := l.GetSettings().Period; period > 0 {
timerIsStopped = false timerIsStopped = false
timer.Reset(period) timer.Reset(period)
} }
@@ -152,14 +166,14 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
return return
case <-timer.C: case <-timer.C:
lastTick = l.timeNow() lastTick = l.timeNow()
l.restart <- struct{}{} l.start <- struct{}{}
timer.Reset(l.GetPeriod()) timer.Reset(l.GetSettings().Period)
case <-l.updateTicker: case <-l.updateTicker:
if !timerIsStopped && !timer.Stop() { if !timerIsStopped && !timer.Stop() {
<-timer.C <-timer.C
} }
timerIsStopped = true timerIsStopped = true
period := l.GetPeriod() period := l.GetSettings().Period
if period == 0 { if period == 0 {
continue continue
} }

View File

@@ -1,32 +0,0 @@
package updater
type Options struct {
Cyberghost bool
Mullvad bool
Nordvpn bool
PIA bool
Privado bool
Purevpn bool
Surfshark bool
Vyprvpn bool
Windscribe bool
Stdout bool // in order to update constants file (maintainer side)
CLI bool
DNSAddress string
}
func NewOptions(dnsAddress string) Options {
return Options{
Cyberghost: true,
Mullvad: true,
Nordvpn: true,
PIA: true,
Purevpn: true,
Surfshark: true,
Vyprvpn: true,
Windscribe: true,
Stdout: false,
CLI: false,
DNSAddress: dnsAddress,
}
}

88
internal/updater/state.go Normal file
View File

@@ -0,0 +1,88 @@
package updater
import (
"fmt"
"reflect"
"sync"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
)
type state struct {
status models.LoopStatus
settings settings.Updater
statusMu sync.RWMutex
periodMu sync.RWMutex
}
func (s *state) setStatusWithLock(status models.LoopStatus) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.status = status
}
func (l *looper) GetStatus() (status models.LoopStatus) {
l.state.statusMu.RLock()
defer l.state.statusMu.RUnlock()
return l.state.status
}
func (l *looper) SetStatus(status models.LoopStatus) (outcome string, err error) {
l.state.statusMu.Lock()
defer l.state.statusMu.Unlock()
existingStatus := l.state.status
switch status {
case constants.Running:
switch existingStatus {
case constants.Starting, constants.Running, constants.Stopping, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Starting
l.state.statusMu.Unlock()
l.start <- struct{}{}
newStatus := <-l.running
l.state.statusMu.Lock()
l.state.status = newStatus
return newStatus.String(), nil
case constants.Stopped:
switch existingStatus {
case constants.Stopped, constants.Stopping, constants.Starting, constants.Crashed:
return fmt.Sprintf("already %s", existingStatus), nil
}
l.loopLock.Lock()
defer l.loopLock.Unlock()
l.state.status = constants.Stopping
l.state.statusMu.Unlock()
l.stop <- struct{}{}
<-l.stopped
l.state.statusMu.Lock()
l.state.status = status
return status.String(), nil
default:
return "", fmt.Errorf("status %q can only be %q or %q",
status, constants.Running, constants.Stopped)
}
}
func (l *looper) GetSettings() (settings settings.Updater) {
l.state.periodMu.RLock()
defer l.state.periodMu.RUnlock()
return l.state.settings
}
func (l *looper) SetSettings(settings settings.Updater) (outcome string) {
l.state.periodMu.Lock()
defer l.state.periodMu.Unlock()
settingsUnchanged := reflect.DeepEqual(settings, l.state.settings)
if settingsUnchanged {
return "settings left unchanged"
}
l.state.settings = settings
l.updateTicker <- struct{}{}
return "settings updated"
}

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
) )
@@ -17,7 +18,7 @@ type Updater interface {
type updater struct { type updater struct {
// configuration // configuration
options Options options settings.Updater
// state // state
servers models.AllServers servers models.AllServers
@@ -30,11 +31,12 @@ type updater struct {
client network.Client client network.Client
} }
func New(options Options, httpClient *http.Client, currentServers models.AllServers, logger logging.Logger) Updater { func New(settings settings.Updater, httpClient *http.Client,
if len(options.DNSAddress) == 0 { currentServers models.AllServers, logger logging.Logger) Updater {
options.DNSAddress = "1.1.1.1" if len(settings.DNSAddress) == 0 {
settings.DNSAddress = "1.1.1.1"
} }
resolver := newResolver(options.DNSAddress) resolver := newResolver(settings.DNSAddress)
const clientTimeout = 10 * time.Second const clientTimeout = 10 * time.Second
return &updater{ return &updater{
logger: logger, logger: logger,
@@ -42,7 +44,7 @@ func New(options Options, httpClient *http.Client, currentServers models.AllServ
println: func(s string) { fmt.Println(s) }, println: func(s string) { fmt.Println(s) },
lookupIP: newLookupIP(resolver), lookupIP: newLookupIP(resolver),
client: network.NewClient(clientTimeout), client: network.NewClient(clientTimeout),
options: options, options: settings,
servers: currentServers, servers: currentServers,
} }
} }