Compare commits
5 Commits
master
...
remove-kee
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f0b8f7292 | ||
|
|
25b381e138 | ||
|
|
35b6b709b2 | ||
|
|
40ea51a3ae | ||
|
|
1a93a41a55 |
5
.github/workflows/ci.yml
vendored
5
.github/workflows/ci.yml
vendored
@@ -93,9 +93,6 @@ jobs:
|
|||||||
- name: Run Gluetun container with Mullvad configuration
|
- name: Run Gluetun container with Mullvad configuration
|
||||||
run: echo -e "${{ secrets.MULLVAD_WIREGUARD_PRIVATE_KEY }}\n${{ secrets.MULLVAD_WIREGUARD_ADDRESS }}" | ./ci/runner mullvad
|
run: echo -e "${{ secrets.MULLVAD_WIREGUARD_PRIVATE_KEY }}\n${{ secrets.MULLVAD_WIREGUARD_ADDRESS }}" | ./ci/runner mullvad
|
||||||
|
|
||||||
- name: Run Gluetun container with ProtonVPN configuration
|
|
||||||
run: echo -e "${{ secrets.PROTONVPN_WIREGUARD_PRIVATE_KEY }}" | ./ci/runner protonvpn
|
|
||||||
|
|
||||||
codeql:
|
codeql:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
@@ -121,7 +118,7 @@ jobs:
|
|||||||
github.event_name == 'release' ||
|
github.event_name == 'release' ||
|
||||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
|
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && github.actor != 'dependabot[bot]')
|
||||||
)
|
)
|
||||||
needs: [verify, verify-private, codeql]
|
needs: [verify, codeql]
|
||||||
permissions:
|
permissions:
|
||||||
actions: read
|
actions: read
|
||||||
contents: read
|
contents: read
|
||||||
|
|||||||
11
Dockerfile
11
Dockerfile
@@ -163,12 +163,10 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
LOG_LEVEL=info \
|
LOG_LEVEL=info \
|
||||||
# Health
|
# Health
|
||||||
HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \
|
HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \
|
||||||
HEALTH_TARGET_ADDRESSES=cloudflare.com:443,github.com:443 \
|
HEALTH_TARGET_ADDRESS=cloudflare.com:443 \
|
||||||
HEALTH_ICMP_TARGET_IPS=1.1.1.1,8.8.8.8 \
|
HEALTH_ICMP_TARGET_IP=1.1.1.1 \
|
||||||
HEALTH_SMALL_CHECK_TYPE=icmp \
|
|
||||||
HEALTH_RESTART_VPN=on \
|
HEALTH_RESTART_VPN=on \
|
||||||
# DNS
|
# DNS
|
||||||
DNS_SERVER=on \
|
|
||||||
DNS_UPSTREAM_RESOLVER_TYPE=DoT \
|
DNS_UPSTREAM_RESOLVER_TYPE=DoT \
|
||||||
DNS_UPSTREAM_RESOLVERS=cloudflare \
|
DNS_UPSTREAM_RESOLVERS=cloudflare \
|
||||||
DNS_BLOCK_IPS= \
|
DNS_BLOCK_IPS= \
|
||||||
@@ -181,8 +179,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
DNS_UNBLOCK_HOSTNAMES= \
|
DNS_UNBLOCK_HOSTNAMES= \
|
||||||
DNS_REBINDING_PROTECTION_EXEMPT_HOSTNAMES= \
|
DNS_REBINDING_PROTECTION_EXEMPT_HOSTNAMES= \
|
||||||
DNS_UPDATE_PERIOD=24h \
|
DNS_UPDATE_PERIOD=24h \
|
||||||
DNS_ADDRESS=127.0.0.1 \
|
DNS_UPSTREAM_PLAIN_ADDRESSES= \
|
||||||
DNS_KEEP_NAMESERVER=off \
|
|
||||||
# HTTP proxy
|
# HTTP proxy
|
||||||
HTTPPROXY= \
|
HTTPPROXY= \
|
||||||
HTTPPROXY_LOG=off \
|
HTTPPROXY_LOG=off \
|
||||||
@@ -208,7 +205,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
UPDATER_PERIOD=0 \
|
UPDATER_PERIOD=0 \
|
||||||
UPDATER_MIN_RATIO=0.8 \
|
UPDATER_MIN_RATIO=0.8 \
|
||||||
UPDATER_VPN_SERVICE_PROVIDERS= \
|
UPDATER_VPN_SERVICE_PROVIDERS= \
|
||||||
UPDATER_PROTONVPN_EMAIL= \
|
UPDATER_PROTONVPN_USERNAME= \
|
||||||
UPDATER_PROTONVPN_PASSWORD= \
|
UPDATER_PROTONVPN_PASSWORD= \
|
||||||
# Public IP
|
# Public IP
|
||||||
PUBLICIP_FILE="/tmp/gluetun/ip" \
|
PUBLICIP_FILE="/tmp/gluetun/ip" \
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ func main() {
|
|||||||
switch os.Args[1] {
|
switch os.Args[1] {
|
||||||
case "mullvad":
|
case "mullvad":
|
||||||
err = internal.MullvadTest(ctx)
|
err = internal.MullvadTest(ctx)
|
||||||
case "protonvpn":
|
|
||||||
err = internal.ProtonVPNTest(ctx)
|
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unknown command: %s", os.Args[1])
|
err = fmt.Errorf("unknown command: %s", os.Args[1])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +1,193 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/docker/docker/api/types/container"
|
||||||
|
"github.com/docker/docker/api/types/network"
|
||||||
|
"github.com/docker/docker/client"
|
||||||
|
v1 "github.com/opencontainers/image-spec/specs-go/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
func MullvadTest(ctx context.Context) error {
|
func MullvadTest(ctx context.Context) error {
|
||||||
expectedSecrets := []string{
|
secrets, err := readSecrets(ctx)
|
||||||
"Wireguard private key",
|
|
||||||
"Wireguard address",
|
|
||||||
}
|
|
||||||
secrets, err := readSecrets(ctx, expectedSecrets)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reading secrets: %w", err)
|
return fmt.Errorf("reading secrets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
env := []string{
|
const timeout = 15 * time.Second
|
||||||
"VPN_SERVICE_PROVIDER=mullvad",
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
"VPN_TYPE=wireguard",
|
defer cancel()
|
||||||
"LOG_LEVEL=debug",
|
|
||||||
"SERVER_COUNTRIES=USA",
|
client, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||||
"WIREGUARD_PRIVATE_KEY=" + secrets[0],
|
if err != nil {
|
||||||
"WIREGUARD_ADDRESSES=" + secrets[1],
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
config := &container.Config{
|
||||||
|
Image: "qmcgaw/gluetun",
|
||||||
|
StopTimeout: ptrTo(3),
|
||||||
|
Env: []string{
|
||||||
|
"VPN_SERVICE_PROVIDER=mullvad",
|
||||||
|
"VPN_TYPE=wireguard",
|
||||||
|
"LOG_LEVEL=debug",
|
||||||
|
"SERVER_COUNTRIES=USA",
|
||||||
|
"WIREGUARD_PRIVATE_KEY=" + secrets.mullvadWireguardPrivateKey,
|
||||||
|
"WIREGUARD_ADDRESSES=" + secrets.mullvadWireguardAddress,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
hostConfig := &container.HostConfig{
|
||||||
|
AutoRemove: true,
|
||||||
|
CapAdd: []string{"NET_ADMIN", "NET_RAW"},
|
||||||
|
}
|
||||||
|
networkConfig := (*network.NetworkingConfig)(nil)
|
||||||
|
platform := (*v1.Platform)(nil)
|
||||||
|
const containerName = "" // auto-generated name
|
||||||
|
|
||||||
|
response, err := client.ContainerCreate(ctx, config, hostConfig, networkConfig, platform, containerName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating container: %w", err)
|
||||||
|
}
|
||||||
|
for _, warning := range response.Warnings {
|
||||||
|
fmt.Println("Warning during container creation:", warning)
|
||||||
|
}
|
||||||
|
containerID := response.ID
|
||||||
|
defer stopContainer(client, containerID)
|
||||||
|
|
||||||
|
beforeStartTime := time.Now()
|
||||||
|
|
||||||
|
err = client.ContainerStart(ctx, containerID, container.StartOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return waitForLogLine(ctx, client, containerID, beforeStartTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrTo[T any](v T) *T { return &v }
|
||||||
|
|
||||||
|
type secrets struct {
|
||||||
|
mullvadWireguardPrivateKey string
|
||||||
|
mullvadWireguardAddress string
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSecrets(ctx context.Context) (secrets, error) {
|
||||||
|
expectedSecrets := [...]string{
|
||||||
|
"Mullvad Wireguard private key",
|
||||||
|
"Mullvad Wireguard address",
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
|
lines := make([]string, 0, len(expectedSecrets))
|
||||||
|
|
||||||
|
for i := range expectedSecrets {
|
||||||
|
fmt.Println("🤫 reading", expectedSecrets[i], "from Stdin...")
|
||||||
|
if !scanner.Scan() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lines = append(lines, strings.TrimSpace(scanner.Text()))
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return secrets{}, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return secrets{}, fmt.Errorf("reading secrets from stdin: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(lines) < len(expectedSecrets) {
|
||||||
|
return secrets{}, fmt.Errorf("expected %d secrets via Stdin, but only received %d",
|
||||||
|
len(expectedSecrets), len(lines))
|
||||||
|
}
|
||||||
|
for i, line := range lines {
|
||||||
|
if line == "" {
|
||||||
|
return secrets{}, fmt.Errorf("secret on line %d/%d was empty", i+1, len(lines))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return secrets{
|
||||||
|
mullvadWireguardPrivateKey: lines[0],
|
||||||
|
mullvadWireguardAddress: lines[1],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopContainer(client *client.Client, containerID string) {
|
||||||
|
const stopTimeout = 5 * time.Second // must be higher than 3s, see above [container.Config]'s StopTimeout field
|
||||||
|
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
|
||||||
|
defer stopCancel()
|
||||||
|
|
||||||
|
err := client.ContainerStop(stopCtx, containerID, container.StopOptions{})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("failed to stop container:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var successRegexp = regexp.MustCompile(`^.+Public IP address is .+$`)
|
||||||
|
|
||||||
|
func waitForLogLine(ctx context.Context, client *client.Client, containerID string,
|
||||||
|
beforeStartTime time.Time,
|
||||||
|
) error {
|
||||||
|
logOptions := container.LogsOptions{
|
||||||
|
ShowStdout: true,
|
||||||
|
Follow: true,
|
||||||
|
Since: beforeStartTime.Format(time.RFC3339Nano),
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := client.ContainerLogs(ctx, containerID, logOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting container logs: %w", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
var linesSeen []string
|
||||||
|
scanner := bufio.NewScanner(reader)
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
if scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if len(line) > 8 { // remove Docker log prefix
|
||||||
|
line = line[8:]
|
||||||
|
}
|
||||||
|
linesSeen = append(linesSeen, line)
|
||||||
|
if successRegexp.MatchString(line) {
|
||||||
|
fmt.Println("✅ Success line logged")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err := scanner.Err()
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
logSeenLines(linesSeen)
|
||||||
|
return fmt.Errorf("reading log stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The scanner is either done or cannot read because of EOF
|
||||||
|
fmt.Println("The log scanner stopped")
|
||||||
|
logSeenLines(linesSeen)
|
||||||
|
|
||||||
|
// Check if the container is still running
|
||||||
|
inspect, err := client.ContainerInspect(ctx, containerID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inspecting container: %w", err)
|
||||||
|
}
|
||||||
|
if !inspect.State.Running {
|
||||||
|
return fmt.Errorf("container stopped unexpectedly while waiting for log line. Exit code: %d", inspect.State.ExitCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func logSeenLines(lines []string) {
|
||||||
|
fmt.Println("Logs seen so far:")
|
||||||
|
for _, line := range lines {
|
||||||
|
fmt.Println(" " + line)
|
||||||
}
|
}
|
||||||
return simpleTest(ctx, env)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ProtonVPNTest(ctx context.Context) error {
|
|
||||||
expectedSecrets := []string{
|
|
||||||
"Wireguard private key",
|
|
||||||
}
|
|
||||||
secrets, err := readSecrets(ctx, expectedSecrets)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("reading secrets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
env := []string{
|
|
||||||
"VPN_SERVICE_PROVIDER=protonvpn",
|
|
||||||
"VPN_TYPE=wireguard",
|
|
||||||
"LOG_LEVEL=debug",
|
|
||||||
"SERVER_COUNTRIES=United States",
|
|
||||||
"WIREGUARD_PRIVATE_KEY=" + secrets[0],
|
|
||||||
}
|
|
||||||
return simpleTest(ctx, env)
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func readSecrets(ctx context.Context, expectedSecrets []string) (lines []string, err error) {
|
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
|
||||||
lines = make([]string, 0, len(expectedSecrets))
|
|
||||||
|
|
||||||
for i := range expectedSecrets {
|
|
||||||
fmt.Println("🤫 reading", expectedSecrets[i], "from Stdin...")
|
|
||||||
if !scanner.Scan() {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
lines = append(lines, strings.TrimSpace(scanner.Text()))
|
|
||||||
fmt.Println("🤫 "+expectedSecrets[i], "secret read successfully")
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("reading secrets from stdin: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(lines) < len(expectedSecrets) {
|
|
||||||
return nil, fmt.Errorf("expected %d secrets via Stdin, but only received %d",
|
|
||||||
len(expectedSecrets), len(lines))
|
|
||||||
}
|
|
||||||
for i, line := range lines {
|
|
||||||
if line == "" {
|
|
||||||
return nil, fmt.Errorf("secret on line %d/%d was empty", i+1, len(lines))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return lines, nil
|
|
||||||
}
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"regexp"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/docker/docker/api/types/container"
|
|
||||||
"github.com/docker/docker/api/types/network"
|
|
||||||
"github.com/docker/docker/client"
|
|
||||||
v1 "github.com/opencontainers/image-spec/specs-go/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ptrTo[T any](v T) *T { return &v }
|
|
||||||
|
|
||||||
func simpleTest(ctx context.Context, env []string) error {
|
|
||||||
const timeout = 30 * time.Second
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
client, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating Docker client: %w", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
config := &container.Config{
|
|
||||||
Image: "qmcgaw/gluetun",
|
|
||||||
StopTimeout: ptrTo(3),
|
|
||||||
Env: env,
|
|
||||||
}
|
|
||||||
hostConfig := &container.HostConfig{
|
|
||||||
AutoRemove: true,
|
|
||||||
CapAdd: []string{"NET_ADMIN", "NET_RAW"},
|
|
||||||
}
|
|
||||||
networkConfig := (*network.NetworkingConfig)(nil)
|
|
||||||
platform := (*v1.Platform)(nil)
|
|
||||||
const containerName = "" // auto-generated name
|
|
||||||
|
|
||||||
response, err := client.ContainerCreate(ctx, config, hostConfig, networkConfig, platform, containerName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating container: %w", err)
|
|
||||||
}
|
|
||||||
for _, warning := range response.Warnings {
|
|
||||||
fmt.Println("Warning during container creation:", warning)
|
|
||||||
}
|
|
||||||
containerID := response.ID
|
|
||||||
defer stopContainer(client, containerID)
|
|
||||||
|
|
||||||
beforeStartTime := time.Now()
|
|
||||||
|
|
||||||
err = client.ContainerStart(ctx, containerID, container.StartOptions{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("starting container: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return waitForLogLine(ctx, client, containerID, beforeStartTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stopContainer(client *client.Client, containerID string) {
|
|
||||||
const stopTimeout = 5 * time.Second // must be higher than 3s, see above [container.Config]'s StopTimeout field
|
|
||||||
stopCtx, stopCancel := context.WithTimeout(context.Background(), stopTimeout)
|
|
||||||
defer stopCancel()
|
|
||||||
|
|
||||||
err := client.ContainerStop(stopCtx, containerID, container.StopOptions{})
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("failed to stop container:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var successRegexp = regexp.MustCompile(`^.+Public IP address is .+$`)
|
|
||||||
|
|
||||||
func waitForLogLine(ctx context.Context, client *client.Client, containerID string,
|
|
||||||
beforeStartTime time.Time,
|
|
||||||
) error {
|
|
||||||
logOptions := container.LogsOptions{
|
|
||||||
ShowStdout: true,
|
|
||||||
Follow: true,
|
|
||||||
Since: beforeStartTime.Format(time.RFC3339Nano),
|
|
||||||
}
|
|
||||||
|
|
||||||
reader, err := client.ContainerLogs(ctx, containerID, logOptions)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error getting container logs: %w", err)
|
|
||||||
}
|
|
||||||
defer reader.Close()
|
|
||||||
|
|
||||||
var linesSeen []string
|
|
||||||
scanner := bufio.NewScanner(reader)
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
if scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if len(line) > 8 { // remove Docker log prefix
|
|
||||||
line = line[8:]
|
|
||||||
}
|
|
||||||
linesSeen = append(linesSeen, line)
|
|
||||||
if successRegexp.MatchString(line) {
|
|
||||||
fmt.Println("✅ Success line logged")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err := scanner.Err()
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
logSeenLines(linesSeen)
|
|
||||||
return fmt.Errorf("reading log stream: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The scanner is either done or cannot read because of EOF
|
|
||||||
fmt.Println("The log scanner stopped")
|
|
||||||
logSeenLines(linesSeen)
|
|
||||||
|
|
||||||
// Check if the container is still running
|
|
||||||
inspect, err := client.ContainerInspect(ctx, containerID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("inspecting container: %w", err)
|
|
||||||
}
|
|
||||||
if !inspect.State.Running {
|
|
||||||
return fmt.Errorf("container stopped unexpectedly while waiting for log line. Exit code: %d", inspect.State.ExitCode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func logSeenLines(lines []string) {
|
|
||||||
fmt.Println("Logs seen so far:")
|
|
||||||
for _, line := range lines {
|
|
||||||
fmt.Println(" " + line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -164,8 +164,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
defer fmt.Println(gluetunLogo)
|
|
||||||
|
|
||||||
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -602,34 +600,3 @@ type RunStarter interface {
|
|||||||
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string,
|
||||||
waitError <-chan error, err error)
|
waitError <-chan error, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const gluetunLogo = ` @@@
|
|
||||||
@@@@
|
|
||||||
@@@@@@
|
|
||||||
@@@@.@@ @@@@@@@@@@
|
|
||||||
@@@@.@@@ @@@@@@@@==@@@@
|
|
||||||
@@@.@..@@ @@@@@@@=@..==@@@@
|
|
||||||
@@@@ @@@.@@.@@ @@@@@@===@@@@.=@@@
|
|
||||||
@...-@@ @@@@.@@.@@@ @@@ @@@@@@=======@@@=@@@@
|
|
||||||
@@@@@@@@ @@@.-%@.+@@@@@@@@ @@@@@%============@@@@
|
|
||||||
@@@.--@..@@@@.-@@@@@@@==============@@@@
|
|
||||||
@@@@ @@@-@--@@.@@.---@@@@@==============#@@@@@
|
|
||||||
@@@ @@@.@@-@@.@@--@@@@@===============@@@@@@
|
|
||||||
@@@@.@--@@@@@@@@@@================@@@@@@@
|
|
||||||
@@@..--@@*@@@@@@================@@@@+*@@
|
|
||||||
@@@.---@@.@@@@=================@@@@--@@
|
|
||||||
@@@-.---@@@@@@================@@@@*--@@@
|
|
||||||
@@@.:-#@@@@@@===============*@@@@.---@@
|
|
||||||
@@@.-------.@@@============@@@@@@.--@@@
|
|
||||||
@@@..--------:@@@=========@@@@@@@@.--@@@
|
|
||||||
@@@.-@@@@@@@@@@@========@@@@@ @@@.--@@
|
|
||||||
@@.@@@@===============@@@@@ @@@@@@---@@@@@@
|
|
||||||
@@@@@@@==============@@@@@@@@@@@@*@---@@@@@@@@
|
|
||||||
@@@@@@=============@@@@@ @@@...------------.*@@@
|
|
||||||
@@@@%===========@@@@@@ @@@..------@@@@.-----.-@@@
|
|
||||||
@@@@@@.=======@@@@@@ @@@.-------@@@@@@-.------=@@
|
|
||||||
@@@@@@@@@===@@@@@@ @@.------@@@@ @@@@.-----@@@
|
|
||||||
@@@==@@@=@@@@@@@ @@@.-@@@@@@@ @@@@@@@--@@
|
|
||||||
@@@@@@@@@@@@@ @@@@@@@@ @@@@@@@
|
|
||||||
@@@@@@@@ @@@@ @@@@
|
|
||||||
`
|
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -10,7 +10,7 @@ require (
|
|||||||
github.com/klauspost/compress v1.18.1
|
github.com/klauspost/compress v1.18.1
|
||||||
github.com/klauspost/pgzip v1.2.6
|
github.com/klauspost/pgzip v1.2.6
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4
|
github.com/pelletier/go-toml/v2 v2.2.4
|
||||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251123213823-54e987293e88
|
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f
|
||||||
github.com/qdm12/gosettings v0.4.4
|
github.com/qdm12/gosettings v0.4.4
|
||||||
github.com/qdm12/goshutdown v0.3.0
|
github.com/qdm12/goshutdown v0.3.0
|
||||||
github.com/qdm12/gosplash v0.2.0
|
github.com/qdm12/gosplash v0.2.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -69,8 +69,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA
|
|||||||
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
|
github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251123213823-54e987293e88 h1:GJ5FALvJ3UmHjVaNYebrfV5zF5You4dq8HfRWZy2loM=
|
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f h1:6wN5D9wACfmXDsQ366egVt0jXY4nqL/QnIwg4nWhXco=
|
||||||
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251123213823-54e987293e88/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
github.com/qdm12/dns/v2 v2.0.0-rc9.0.20251114155417-248acd28339f/go.mod h1:98foWgXJZ+g8gJIuO+fdO+oWpFei5WShMFTeN4Im2lE=
|
||||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
|
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978 h1:TRGpCU1l0lNwtogEUSs5U+RFceYxkAJUmrGabno7J5c=
|
||||||
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
|
github.com/qdm12/goservices v0.1.1-0.20251104135713-6bee97bd4978/go.mod h1:D1Po4CRQLYjccnAR2JsVlN1sBMgQrcNLONbvyuzcdTg=
|
||||||
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
|
github.com/qdm12/gosettings v0.4.4 h1:SM6tOZDf6k8qbjWU8KWyBF4mWIixfsKCfh9DGRLHlj4=
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ type UpdaterLogger interface {
|
|||||||
func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error {
|
func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) error {
|
||||||
options := settings.Updater{}
|
options := settings.Updater{}
|
||||||
var endUserMode, maintainerMode, updateAll bool
|
var endUserMode, maintainerMode, updateAll bool
|
||||||
var csvProviders, ipToken, protonUsername, protonEmail, protonPassword string
|
var csvProviders, ipToken, protonUsername, protonPassword string
|
||||||
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
||||||
flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)")
|
flagSet.BoolVar(&endUserMode, "enduser", false, "Write results to /gluetun/servers.json (for end users)")
|
||||||
flagSet.BoolVar(&maintainerMode, "maintainer", false,
|
flagSet.BoolVar(&maintainerMode, "maintainer", false,
|
||||||
@@ -50,9 +50,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
|||||||
flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers")
|
flagSet.BoolVar(&updateAll, "all", false, "Update servers for all VPN providers")
|
||||||
flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for")
|
flagSet.StringVar(&csvProviders, "providers", "", "CSV string of VPN providers to update server data for")
|
||||||
flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use")
|
flagSet.StringVar(&ipToken, "ip-token", "", "IP data service token (e.g. ipinfo.io) to use")
|
||||||
flagSet.StringVar(&protonUsername, "proton-username", "",
|
flagSet.StringVar(&protonUsername, "proton-username", "", "Username to use to authenticate with Proton")
|
||||||
"(Retro-compatibility) Username to use to authenticate with Proton. Use -proton-email instead.") // v4 remove this
|
|
||||||
flagSet.StringVar(&protonEmail, "proton-email", "", "Email to use to authenticate with Proton")
|
|
||||||
flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton")
|
flagSet.StringVar(&protonPassword, "proton-password", "", "Password to use to authenticate with Proton")
|
||||||
if err := flagSet.Parse(args); err != nil {
|
if err := flagSet.Parse(args); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -72,12 +70,7 @@ func (c *CLI) Update(ctx context.Context, args []string, logger UpdaterLogger) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(options.Providers, providers.Protonvpn) {
|
if slices.Contains(options.Providers, providers.Protonvpn) {
|
||||||
if protonEmail == "" && protonUsername != "" {
|
options.ProtonUsername = &protonUsername
|
||||||
protonEmail = protonUsername + "@protonmail.com"
|
|
||||||
logger.Warn("use -proton-email instead of -proton-username in the future. " +
|
|
||||||
"This assumes the email is " + protonEmail + " and may not work.")
|
|
||||||
}
|
|
||||||
options.ProtonEmail = &protonEmail
|
|
||||||
options.ProtonPassword = &protonPassword
|
options.ProtonPassword = &protonPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ func readObsolete(r *reader.Reader) (warnings []string) {
|
|||||||
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
|
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
|
||||||
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
|
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
|
||||||
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
|
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
|
||||||
|
"DNS_SERVER": "DNS_SERVER is obsolete because the forwarding server is always enabled.",
|
||||||
|
"DOT": "DOT is obsolete because the forwarding server is always enabled.",
|
||||||
|
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because the forwarding server is always used and " +
|
||||||
|
"forwards local names to private DNS resolvers found in /etc/resolv.conf",
|
||||||
}
|
}
|
||||||
sortedKeys := maps.Keys(keyToMessage)
|
sortedKeys := maps.Keys(keyToMessage)
|
||||||
slices.Sort(sortedKeys)
|
slices.Sort(sortedKeys)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/provider"
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
@@ -13,20 +14,25 @@ import (
|
|||||||
"github.com/qdm12/gotree"
|
"github.com/qdm12/gotree"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DNSUpstreamTypeDot = "dot"
|
||||||
|
DNSUpstreamTypeDoh = "doh"
|
||||||
|
DNSUpstreamTypePlain = "plain"
|
||||||
|
)
|
||||||
|
|
||||||
// DNS contains settings to configure DNS.
|
// DNS contains settings to configure DNS.
|
||||||
type DNS struct {
|
type DNS struct {
|
||||||
// ServerEnabled is true if the server should be running
|
// UpstreamType can be [dnsUpstreamTypeDot], [dnsUpstreamTypeDoh]
|
||||||
// and used. It defaults to true, and cannot be nil
|
// or [dnsUpstreamTypePlain]. It defaults to [dnsUpstreamTypeDot].
|
||||||
// in the internal state.
|
|
||||||
ServerEnabled *bool
|
|
||||||
// UpstreamType can be dot or plain, and defaults to dot.
|
|
||||||
UpstreamType string `json:"upstream_type"`
|
UpstreamType string `json:"upstream_type"`
|
||||||
// UpdatePeriod is the period to update DNS block lists.
|
// UpdatePeriod is the period to update DNS block lists.
|
||||||
// It can be set to 0 to disable the update.
|
// It can be set to 0 to disable the update.
|
||||||
// It defaults to 24h and cannot be nil in
|
// It defaults to 24h and cannot be nil in
|
||||||
// the internal state.
|
// the internal state.
|
||||||
UpdatePeriod *time.Duration
|
UpdatePeriod *time.Duration
|
||||||
// Providers is a list of DNS providers
|
// Providers is a list of DNS providers.
|
||||||
|
// It defaults to either ["cloudflare"] or [] if the
|
||||||
|
// UpstreamPlainAddresses field is set.
|
||||||
Providers []string `json:"providers"`
|
Providers []string `json:"providers"`
|
||||||
// Caching is true if the server should cache
|
// Caching is true if the server should cache
|
||||||
// DNS responses.
|
// DNS responses.
|
||||||
@@ -36,32 +42,23 @@ type DNS struct {
|
|||||||
// Blacklist contains settings to configure the filter
|
// Blacklist contains settings to configure the filter
|
||||||
// block lists.
|
// block lists.
|
||||||
Blacklist DNSBlacklist
|
Blacklist DNSBlacklist
|
||||||
// ServerAddress is the DNS server to use inside
|
// UpstreamPlainAddresses are the upstream plaintext DNS resolver
|
||||||
// the Go program and for the system.
|
// addresses to use by the built-in DNS server forwarder.
|
||||||
// It defaults to '127.0.0.1' to be used with the
|
// Note, if the upstream type is [dnsUpstreamTypePlain] these are merged
|
||||||
// local server. It cannot be the zero value in the internal
|
// together with provider names set in the Providers field.
|
||||||
// state.
|
// If this field is set, the Providers field will default to the empty slice.
|
||||||
ServerAddress netip.Addr
|
UpstreamPlainAddresses []netip.AddrPort
|
||||||
// KeepNameserver is true if the existing DNS server
|
|
||||||
// found in /etc/resolv.conf should be used
|
|
||||||
// Note setting this to true will likely DNS traffic
|
|
||||||
// outside the VPN tunnel since it would go through
|
|
||||||
// the local DNS server of your Docker/Kubernetes
|
|
||||||
// configuration, which is likely not going through the tunnel.
|
|
||||||
// This will also disable the DNS forwarder server and the
|
|
||||||
// `ServerAddress` field will be ignored.
|
|
||||||
// It defaults to false and cannot be nil in the
|
|
||||||
// internal state.
|
|
||||||
KeepNameserver *bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
|
ErrDNSUpstreamTypeNotValid = errors.New("DNS upstream type is not valid")
|
||||||
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
|
ErrDNSUpdatePeriodTooShort = errors.New("update period is too short")
|
||||||
|
ErrDNSUpstreamPlainNoIPv6 = errors.New("upstream plain addresses do not contain any IPv6 address")
|
||||||
|
ErrDNSUpstreamPlainNoIPv4 = errors.New("upstream plain addresses do not contain any IPv4 address")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d DNS) validate() (err error) {
|
func (d DNS) validate() (err error) {
|
||||||
if !helpers.IsOneOf(d.UpstreamType, "dot", "doh", "plain") {
|
if !helpers.IsOneOf(d.UpstreamType, DNSUpstreamTypeDot, DNSUpstreamTypeDoh, DNSUpstreamTypePlain) {
|
||||||
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
|
return fmt.Errorf("%w: %s", ErrDNSUpstreamTypeNotValid, d.UpstreamType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,6 +76,18 @@ func (d DNS) validate() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if d.UpstreamType == DNSUpstreamTypePlain {
|
||||||
|
if *d.IPv6 && !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool {
|
||||||
|
return addrPort.Addr().Is6()
|
||||||
|
}) {
|
||||||
|
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
|
||||||
|
} else if !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool {
|
||||||
|
return addrPort.Addr().Is4()
|
||||||
|
}) {
|
||||||
|
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = d.Blacklist.validate()
|
err = d.Blacklist.validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -89,15 +98,13 @@ func (d DNS) validate() (err error) {
|
|||||||
|
|
||||||
func (d *DNS) Copy() (copied DNS) {
|
func (d *DNS) Copy() (copied DNS) {
|
||||||
return DNS{
|
return DNS{
|
||||||
ServerEnabled: gosettings.CopyPointer(d.ServerEnabled),
|
UpstreamType: d.UpstreamType,
|
||||||
UpstreamType: d.UpstreamType,
|
UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod),
|
||||||
UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod),
|
Providers: gosettings.CopySlice(d.Providers),
|
||||||
Providers: gosettings.CopySlice(d.Providers),
|
Caching: gosettings.CopyPointer(d.Caching),
|
||||||
Caching: gosettings.CopyPointer(d.Caching),
|
IPv6: gosettings.CopyPointer(d.IPv6),
|
||||||
IPv6: gosettings.CopyPointer(d.IPv6),
|
Blacklist: d.Blacklist.copy(),
|
||||||
Blacklist: d.Blacklist.copy(),
|
UpstreamPlainAddresses: d.UpstreamPlainAddresses,
|
||||||
ServerAddress: d.ServerAddress,
|
|
||||||
KeepNameserver: gosettings.CopyPointer(d.KeepNameserver),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,20 +112,17 @@ func (d *DNS) Copy() (copied DNS) {
|
|||||||
// settings object with any field set in the other
|
// settings object with any field set in the other
|
||||||
// settings.
|
// settings.
|
||||||
func (d *DNS) overrideWith(other DNS) {
|
func (d *DNS) overrideWith(other DNS) {
|
||||||
d.ServerEnabled = gosettings.OverrideWithPointer(d.ServerEnabled, other.ServerEnabled)
|
|
||||||
d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType)
|
d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType)
|
||||||
d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod)
|
d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod)
|
||||||
d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers)
|
d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers)
|
||||||
d.Caching = gosettings.OverrideWithPointer(d.Caching, other.Caching)
|
d.Caching = gosettings.OverrideWithPointer(d.Caching, other.Caching)
|
||||||
d.IPv6 = gosettings.OverrideWithPointer(d.IPv6, other.IPv6)
|
d.IPv6 = gosettings.OverrideWithPointer(d.IPv6, other.IPv6)
|
||||||
d.Blacklist.overrideWith(other.Blacklist)
|
d.Blacklist.overrideWith(other.Blacklist)
|
||||||
d.ServerAddress = gosettings.OverrideWithValidator(d.ServerAddress, other.ServerAddress)
|
d.UpstreamPlainAddresses = gosettings.OverrideWithSlice(d.UpstreamPlainAddresses, other.UpstreamPlainAddresses)
|
||||||
d.KeepNameserver = gosettings.OverrideWithPointer(d.KeepNameserver, other.KeepNameserver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNS) setDefaults() {
|
func (d *DNS) setDefaults() {
|
||||||
d.ServerEnabled = gosettings.DefaultPointer(d.ServerEnabled, true)
|
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, DNSUpstreamTypeDot)
|
||||||
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, "dot")
|
|
||||||
const defaultUpdatePeriod = 24 * time.Hour
|
const defaultUpdatePeriod = 24 * time.Hour
|
||||||
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
|
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
|
||||||
d.Providers = gosettings.DefaultSlice(d.Providers, []string{
|
d.Providers = gosettings.DefaultSlice(d.Providers, []string{
|
||||||
@@ -127,26 +131,53 @@ func (d *DNS) setDefaults() {
|
|||||||
d.Caching = gosettings.DefaultPointer(d.Caching, true)
|
d.Caching = gosettings.DefaultPointer(d.Caching, true)
|
||||||
d.IPv6 = gosettings.DefaultPointer(d.IPv6, false)
|
d.IPv6 = gosettings.DefaultPointer(d.IPv6, false)
|
||||||
d.Blacklist.setDefaults()
|
d.Blacklist.setDefaults()
|
||||||
d.ServerAddress = gosettings.DefaultValidator(d.ServerAddress,
|
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{})
|
||||||
netip.AddrFrom4([4]byte{127, 0, 0, 1}))
|
}
|
||||||
d.KeepNameserver = gosettings.DefaultPointer(d.KeepNameserver, false)
|
|
||||||
|
func defaultDNSProviders() []string {
|
||||||
|
return []string{
|
||||||
|
provider.Cloudflare().Name,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d DNS) GetFirstPlaintextIPv4() (ipv4 netip.Addr) {
|
func (d DNS) GetFirstPlaintextIPv4() (ipv4 netip.Addr) {
|
||||||
localhost := netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
if d.UpstreamType == DNSUpstreamTypePlain {
|
||||||
if d.ServerAddress.Compare(localhost) != 0 && d.ServerAddress.Is4() {
|
for _, addrPort := range d.UpstreamPlainAddresses {
|
||||||
return d.ServerAddress
|
if addrPort.Addr().Is4() {
|
||||||
|
return addrPort.Addr()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ipv4 = findPlainIPv4InProviders(d.Providers)
|
||||||
|
if ipv4.IsValid() {
|
||||||
|
return ipv4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Either:
|
||||||
|
// - all upstream plain addresses are IPv6 and no provider is set
|
||||||
|
// - all providers set do not have a plaintext IPv4 address
|
||||||
|
ipv4 = findPlainIPv4InProviders(defaultDNSProviders())
|
||||||
|
if !ipv4.IsValid() {
|
||||||
|
panic("no plaintext IPv4 address found in default DNS providers")
|
||||||
|
}
|
||||||
|
return ipv4
|
||||||
|
}
|
||||||
|
|
||||||
|
func findPlainIPv4InProviders(providerNames []string) netip.Addr {
|
||||||
providers := provider.NewProviders()
|
providers := provider.NewProviders()
|
||||||
provider, err := providers.Get(d.Providers[0])
|
for _, name := range providerNames {
|
||||||
if err != nil {
|
provider, err := providers.Get(name)
|
||||||
// Settings should be validated before calling this function,
|
if err != nil {
|
||||||
// so an error happening here is a programming error.
|
// Settings should be validated before calling this function,
|
||||||
panic(err)
|
// so an error happening here is a programming error.
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if len(provider.Plain.IPv4) > 0 {
|
||||||
|
return provider.Plain.IPv4[0].Addr()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return netip.Addr{}
|
||||||
return provider.Plain.IPv4[0].Addr()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d DNS) String() string {
|
func (d DNS) String() string {
|
||||||
@@ -155,22 +186,22 @@ func (d DNS) String() string {
|
|||||||
|
|
||||||
func (d DNS) toLinesNode() (node *gotree.Node) {
|
func (d DNS) toLinesNode() (node *gotree.Node) {
|
||||||
node = gotree.New("DNS settings:")
|
node = gotree.New("DNS settings:")
|
||||||
node.Appendf("Keep existing nameserver(s): %s", gosettings.BoolToYesNo(d.KeepNameserver))
|
|
||||||
if *d.KeepNameserver {
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
node.Appendf("DNS server address to use: %s", d.ServerAddress)
|
|
||||||
|
|
||||||
node.Appendf("DNS forwarder server enabled: %s", gosettings.BoolToYesNo(d.ServerEnabled))
|
|
||||||
if !*d.ServerEnabled {
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
|
|
||||||
node.Appendf("Upstream resolver type: %s", d.UpstreamType)
|
node.Appendf("Upstream resolver type: %s", d.UpstreamType)
|
||||||
|
|
||||||
upstreamResolvers := node.Append("Upstream resolvers:")
|
upstreamResolvers := node.Append("Upstream resolvers:")
|
||||||
for _, provider := range d.Providers {
|
if len(d.UpstreamPlainAddresses) > 0 {
|
||||||
upstreamResolvers.Append(provider)
|
if d.UpstreamType == DNSUpstreamTypePlain {
|
||||||
|
for _, addr := range d.UpstreamPlainAddresses {
|
||||||
|
upstreamResolvers.Append(addr.String())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
node.Appendf("Upstream plain addresses: ignored because upstream type is not plain")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, provider := range d.Providers {
|
||||||
|
upstreamResolvers.Append(provider)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
node.Appendf("Caching: %s", gosettings.BoolToYesNo(d.Caching))
|
node.Appendf("Caching: %s", gosettings.BoolToYesNo(d.Caching))
|
||||||
@@ -188,11 +219,6 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNS) read(r *reader.Reader) (err error) {
|
func (d *DNS) read(r *reader.Reader) (err error) {
|
||||||
d.ServerEnabled, err = r.BoolPtr("DNS_SERVER", reader.RetroKeys("DOT"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE")
|
d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE")
|
||||||
|
|
||||||
d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD")
|
d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD")
|
||||||
@@ -217,15 +243,43 @@ func (d *DNS) read(r *reader.Reader) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.ServerAddress, err = r.NetipAddr("DNS_ADDRESS", reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"))
|
err = d.readUpstreamPlainAddresses(r)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.KeepNameserver, err = r.BoolPtr("DNS_KEEP_NAMESERVER")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
|
||||||
|
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_TYPE=plain
|
||||||
|
// for these to be used. This is an added safety measure to reduce misunderstandings, and
|
||||||
|
// reduce odd settings overrides.
|
||||||
|
d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retro-compatibility - remove in v4
|
||||||
|
// If DNS_ADDRESS is set to a non-localhost address, append it to the other
|
||||||
|
// upstream plain addresses, assuming port 53, and force the upstream type to plain AND
|
||||||
|
// clear any user picked providers, to maintain retro-compatibility behavior.
|
||||||
|
serverAddress, err := r.NetipAddr("DNS_ADDRESS",
|
||||||
|
reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"),
|
||||||
|
reader.IsRetro("DNS_UPSTREAM_PLAIN_ADDRESSES"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if !serverAddress.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
isLocalhost := serverAddress.Compare(netip.AddrFrom4([4]byte{127, 0, 0, 1})) == 0
|
||||||
|
if isLocalhost {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
const defaultPlainPort = 53
|
||||||
|
addrPort := netip.AddrPortFrom(serverAddress, defaultPlainPort)
|
||||||
|
d.UpstreamPlainAddresses = append(d.UpstreamPlainAddresses, addrPort)
|
||||||
|
d.UpstreamType = DNSUpstreamTypePlain
|
||||||
|
d.Providers = []string{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
26
internal/configuration/settings/dns_test.go
Normal file
26
internal/configuration/settings/dns_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/qdm12/dns/v2/pkg/provider"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_defaultDNSProviders(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
names := defaultDNSProviders()
|
||||||
|
|
||||||
|
found := false
|
||||||
|
providers := provider.NewProviders()
|
||||||
|
for _, name := range names {
|
||||||
|
provider, err := providers.Get(name)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if len(provider.Plain.IPv4) > 0 {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, found, "no default DNS provider has a plaintext IPv4 address")
|
||||||
|
}
|
||||||
@@ -37,7 +37,7 @@ var (
|
|||||||
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
|
ErrSystemTimezoneNotValid = errors.New("timezone is not valid")
|
||||||
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
|
||||||
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
|
ErrUpdaterProtonPasswordMissing = errors.New("proton password is missing")
|
||||||
ErrUpdaterProtonEmailMissing = errors.New("proton email is missing")
|
ErrUpdaterProtonUsernameMissing = errors.New("proton username is missing")
|
||||||
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
|
||||||
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
|
||||||
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
|
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package settings
|
package settings
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -18,63 +17,34 @@ type Health struct {
|
|||||||
// for the health check server.
|
// for the health check server.
|
||||||
// It cannot be the empty string in the internal state.
|
// It cannot be the empty string in the internal state.
|
||||||
ServerAddress string
|
ServerAddress string
|
||||||
// TargetAddresses are the addresses (host or host:port)
|
// TargetAddress is the address (host or host:port)
|
||||||
// to TCP TLS dial to periodically for the health check.
|
// to TCP TLS dial to periodically for the health check.
|
||||||
// Addresses after the first one are used as fallbacks for retries.
|
// It cannot be the empty string in the internal state.
|
||||||
// It cannot be empty in the internal state.
|
TargetAddress string
|
||||||
TargetAddresses []string
|
// ICMPTargetIP is the IP address to use for ICMP echo requests
|
||||||
// ICMPTargetIPs are the IP addresses to use for ICMP echo requests
|
// in the health checker. It can be set to an unspecified address (0.0.0.0)
|
||||||
// in the health checker. The slice can be set to a single
|
// such that the VPN server IP is used, which is also the default behavior.
|
||||||
// unspecified address (0.0.0.0) such that the VPN server IP is used,
|
ICMPTargetIP netip.Addr
|
||||||
// although this can be less reliable. It defaults to [1.1.1.1,8.8.8.8],
|
|
||||||
// and cannot be left empty in the internal state.
|
|
||||||
ICMPTargetIPs []netip.Addr
|
|
||||||
// SmallCheckType is the type of small health check to perform.
|
|
||||||
// It can be "icmp" or "dns", and defaults to "icmp".
|
|
||||||
// Note it changes automatically to dns if icmp is not supported.
|
|
||||||
SmallCheckType string
|
|
||||||
// RestartVPN indicates whether to restart the VPN connection
|
// RestartVPN indicates whether to restart the VPN connection
|
||||||
// when the healthcheck fails.
|
// when the healthcheck fails.
|
||||||
RestartVPN *bool
|
RestartVPN *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrICMPTargetIPNotValid = errors.New("ICMP target IP address is not valid")
|
|
||||||
ErrICMPTargetIPsNotCompatible = errors.New("ICMP target IP addresses are not compatible")
|
|
||||||
ErrSmallCheckTypeNotValid = errors.New("small check type is not valid")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (h Health) Validate() (err error) {
|
func (h Health) Validate() (err error) {
|
||||||
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
|
err = validate.ListeningAddress(h.ServerAddress, os.Getuid())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server listening address is not valid: %w", err)
|
return fmt.Errorf("server listening address is not valid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ip := range h.ICMPTargetIPs {
|
|
||||||
switch {
|
|
||||||
case !ip.IsValid():
|
|
||||||
return fmt.Errorf("%w: %s", ErrICMPTargetIPNotValid, ip)
|
|
||||||
case ip.IsUnspecified() && len(h.ICMPTargetIPs) > 1:
|
|
||||||
return fmt.Errorf("%w: only a single IP address must be set if it is to be unspecified",
|
|
||||||
ErrICMPTargetIPsNotCompatible)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = validate.IsOneOf(h.SmallCheckType, "icmp", "dns")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: %s", ErrSmallCheckTypeNotValid, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Health) copy() (copied Health) {
|
func (h *Health) copy() (copied Health) {
|
||||||
return Health{
|
return Health{
|
||||||
ServerAddress: h.ServerAddress,
|
ServerAddress: h.ServerAddress,
|
||||||
TargetAddresses: h.TargetAddresses,
|
TargetAddress: h.TargetAddress,
|
||||||
ICMPTargetIPs: gosettings.CopySlice(h.ICMPTargetIPs),
|
ICMPTargetIP: h.ICMPTargetIP,
|
||||||
SmallCheckType: h.SmallCheckType,
|
RestartVPN: gosettings.CopyPointer(h.RestartVPN),
|
||||||
RestartVPN: gosettings.CopyPointer(h.RestartVPN),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,20 +53,15 @@ func (h *Health) copy() (copied Health) {
|
|||||||
// settings.
|
// settings.
|
||||||
func (h *Health) OverrideWith(other Health) {
|
func (h *Health) OverrideWith(other Health) {
|
||||||
h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress)
|
h.ServerAddress = gosettings.OverrideWithComparable(h.ServerAddress, other.ServerAddress)
|
||||||
h.TargetAddresses = gosettings.OverrideWithSlice(h.TargetAddresses, other.TargetAddresses)
|
h.TargetAddress = gosettings.OverrideWithComparable(h.TargetAddress, other.TargetAddress)
|
||||||
h.ICMPTargetIPs = gosettings.OverrideWithSlice(h.ICMPTargetIPs, other.ICMPTargetIPs)
|
h.ICMPTargetIP = gosettings.OverrideWithComparable(h.ICMPTargetIP, other.ICMPTargetIP)
|
||||||
h.SmallCheckType = gosettings.OverrideWithComparable(h.SmallCheckType, other.SmallCheckType)
|
|
||||||
h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN)
|
h.RestartVPN = gosettings.OverrideWithPointer(h.RestartVPN, other.RestartVPN)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Health) SetDefaults() {
|
func (h *Health) SetDefaults() {
|
||||||
h.ServerAddress = gosettings.DefaultComparable(h.ServerAddress, "127.0.0.1:9999")
|
h.ServerAddress = gosettings.DefaultComparable(h.ServerAddress, "127.0.0.1:9999")
|
||||||
h.TargetAddresses = gosettings.DefaultSlice(h.TargetAddresses, []string{"cloudflare.com:443", "github.com:443"})
|
h.TargetAddress = gosettings.DefaultComparable(h.TargetAddress, "cloudflare.com:443")
|
||||||
h.ICMPTargetIPs = gosettings.DefaultSlice(h.ICMPTargetIPs, []netip.Addr{
|
h.ICMPTargetIP = gosettings.DefaultComparable(h.ICMPTargetIP, netip.IPv4Unspecified()) // use the VPN server IP
|
||||||
netip.AddrFrom4([4]byte{1, 1, 1, 1}),
|
|
||||||
netip.AddrFrom4([4]byte{8, 8, 8, 8}),
|
|
||||||
})
|
|
||||||
h.SmallCheckType = gosettings.DefaultComparable(h.SmallCheckType, "icmp")
|
|
||||||
h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true)
|
h.RestartVPN = gosettings.DefaultPointer(h.RestartVPN, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,37 +72,24 @@ func (h Health) String() string {
|
|||||||
func (h Health) toLinesNode() (node *gotree.Node) {
|
func (h Health) toLinesNode() (node *gotree.Node) {
|
||||||
node = gotree.New("Health settings:")
|
node = gotree.New("Health settings:")
|
||||||
node.Appendf("Server listening address: %s", h.ServerAddress)
|
node.Appendf("Server listening address: %s", h.ServerAddress)
|
||||||
targetAddrs := node.Appendf("Target addresses:")
|
node.Appendf("Target address: %s", h.TargetAddress)
|
||||||
for _, targetAddr := range h.TargetAddresses {
|
icmpTarget := "VPN server IP"
|
||||||
targetAddrs.Append(targetAddr)
|
if !h.ICMPTargetIP.IsUnspecified() {
|
||||||
}
|
icmpTarget = h.ICMPTargetIP.String()
|
||||||
switch h.SmallCheckType {
|
|
||||||
case "icmp":
|
|
||||||
icmpNode := node.Appendf("Small health check type: ICMP echo request")
|
|
||||||
if len(h.ICMPTargetIPs) == 1 && h.ICMPTargetIPs[0].IsUnspecified() {
|
|
||||||
icmpNode.Appendf("ICMP target IP: VPN server IP address")
|
|
||||||
} else {
|
|
||||||
icmpIPs := icmpNode.Appendf("ICMP target IPs:")
|
|
||||||
for _, ip := range h.ICMPTargetIPs {
|
|
||||||
icmpIPs.Append(ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "dns":
|
|
||||||
node.Appendf("Small health check type: Plain DNS lookup over UDP")
|
|
||||||
}
|
}
|
||||||
|
node.Appendf("ICMP target IP: %s", icmpTarget)
|
||||||
node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN))
|
node.Appendf("Restart VPN on healthcheck failure: %s", gosettings.BoolToYesNo(h.RestartVPN))
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Health) Read(r *reader.Reader) (err error) {
|
func (h *Health) Read(r *reader.Reader) (err error) {
|
||||||
h.ServerAddress = r.String("HEALTH_SERVER_ADDRESS")
|
h.ServerAddress = r.String("HEALTH_SERVER_ADDRESS")
|
||||||
h.TargetAddresses = r.CSV("HEALTH_TARGET_ADDRESSES",
|
h.TargetAddress = r.String("HEALTH_TARGET_ADDRESS",
|
||||||
reader.RetroKeys("HEALTH_ADDRESS_TO_PING", "HEALTH_TARGET_ADDRESS"))
|
reader.RetroKeys("HEALTH_ADDRESS_TO_PING"))
|
||||||
h.ICMPTargetIPs, err = r.CSVNetipAddresses("HEALTH_ICMP_TARGET_IPS", reader.RetroKeys("HEALTH_ICMP_TARGET_IP"))
|
h.ICMPTargetIP, err = r.NetipAddr("HEALTH_ICMP_TARGET_IP")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.SmallCheckType = r.String("HEALTH_SMALL_CHECK_TYPE")
|
|
||||||
h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN")
|
h.RestartVPN, err = r.BoolPtr("HEALTH_RESTART_VPN")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package settings
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||||
@@ -25,12 +24,6 @@ type OpenVPNSelection struct {
|
|||||||
// and can be udp or tcp. It cannot be the empty string
|
// and can be udp or tcp. It cannot be the empty string
|
||||||
// in the internal state.
|
// in the internal state.
|
||||||
Protocol string `json:"protocol"`
|
Protocol string `json:"protocol"`
|
||||||
// EndpointIP is the server endpoint IP address.
|
|
||||||
// If set, it overrides any IP address from the picked
|
|
||||||
// built-in server connection. To indicate it should
|
|
||||||
// not be used, it should be set to [netip.IPv4Unspecified].
|
|
||||||
// It can never be the zero value in the internal state.
|
|
||||||
EndpointIP netip.Addr `json:"endpoint_ip"`
|
|
||||||
// CustomPort is the OpenVPN server endpoint port.
|
// CustomPort is the OpenVPN server endpoint port.
|
||||||
// It can be set to 0 to indicate no custom port should
|
// It can be set to 0 to indicate no custom port should
|
||||||
// be used. It cannot be nil in the internal state.
|
// be used. It cannot be nil in the internal state.
|
||||||
@@ -149,7 +142,6 @@ func (o *OpenVPNSelection) copy() (copied OpenVPNSelection) {
|
|||||||
return OpenVPNSelection{
|
return OpenVPNSelection{
|
||||||
ConfFile: gosettings.CopyPointer(o.ConfFile),
|
ConfFile: gosettings.CopyPointer(o.ConfFile),
|
||||||
Protocol: o.Protocol,
|
Protocol: o.Protocol,
|
||||||
EndpointIP: o.EndpointIP,
|
|
||||||
CustomPort: gosettings.CopyPointer(o.CustomPort),
|
CustomPort: gosettings.CopyPointer(o.CustomPort),
|
||||||
PIAEncPreset: gosettings.CopyPointer(o.PIAEncPreset),
|
PIAEncPreset: gosettings.CopyPointer(o.PIAEncPreset),
|
||||||
}
|
}
|
||||||
@@ -159,14 +151,12 @@ func (o *OpenVPNSelection) overrideWith(other OpenVPNSelection) {
|
|||||||
o.ConfFile = gosettings.OverrideWithPointer(o.ConfFile, other.ConfFile)
|
o.ConfFile = gosettings.OverrideWithPointer(o.ConfFile, other.ConfFile)
|
||||||
o.Protocol = gosettings.OverrideWithComparable(o.Protocol, other.Protocol)
|
o.Protocol = gosettings.OverrideWithComparable(o.Protocol, other.Protocol)
|
||||||
o.CustomPort = gosettings.OverrideWithPointer(o.CustomPort, other.CustomPort)
|
o.CustomPort = gosettings.OverrideWithPointer(o.CustomPort, other.CustomPort)
|
||||||
o.EndpointIP = gosettings.OverrideWithValidator(o.EndpointIP, other.EndpointIP)
|
|
||||||
o.PIAEncPreset = gosettings.OverrideWithPointer(o.PIAEncPreset, other.PIAEncPreset)
|
o.PIAEncPreset = gosettings.OverrideWithPointer(o.PIAEncPreset, other.PIAEncPreset)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenVPNSelection) setDefaults(vpnProvider string) {
|
func (o *OpenVPNSelection) setDefaults(vpnProvider string) {
|
||||||
o.ConfFile = gosettings.DefaultPointer(o.ConfFile, "")
|
o.ConfFile = gosettings.DefaultPointer(o.ConfFile, "")
|
||||||
o.Protocol = gosettings.DefaultComparable(o.Protocol, constants.UDP)
|
o.Protocol = gosettings.DefaultComparable(o.Protocol, constants.UDP)
|
||||||
o.EndpointIP = gosettings.DefaultValidator(o.EndpointIP, netip.IPv4Unspecified())
|
|
||||||
o.CustomPort = gosettings.DefaultPointer(o.CustomPort, 0)
|
o.CustomPort = gosettings.DefaultPointer(o.CustomPort, 0)
|
||||||
|
|
||||||
var defaultEncPreset string
|
var defaultEncPreset string
|
||||||
@@ -184,10 +174,6 @@ func (o OpenVPNSelection) toLinesNode() (node *gotree.Node) {
|
|||||||
node = gotree.New("OpenVPN server selection settings:")
|
node = gotree.New("OpenVPN server selection settings:")
|
||||||
node.Appendf("Protocol: %s", strings.ToUpper(o.Protocol))
|
node.Appendf("Protocol: %s", strings.ToUpper(o.Protocol))
|
||||||
|
|
||||||
if !o.EndpointIP.IsUnspecified() {
|
|
||||||
node.Appendf("Endpoint IP address: %s", o.EndpointIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
if *o.CustomPort != 0 {
|
if *o.CustomPort != 0 {
|
||||||
node.Appendf("Custom port: %d", *o.CustomPort)
|
node.Appendf("Custom port: %d", *o.CustomPort)
|
||||||
}
|
}
|
||||||
@@ -208,12 +194,6 @@ func (o *OpenVPNSelection) read(r *reader.Reader) (err error) {
|
|||||||
|
|
||||||
o.Protocol = r.String("OPENVPN_PROTOCOL", reader.RetroKeys("PROTOCOL"))
|
o.Protocol = r.String("OPENVPN_PROTOCOL", reader.RetroKeys("PROTOCOL"))
|
||||||
|
|
||||||
o.EndpointIP, err = r.NetipAddr("OPENVPN_ENDPOINT_IP",
|
|
||||||
reader.RetroKeys("OPENVPN_TARGET_IP", "VPN_ENDPOINT_IP"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
o.CustomPort, err = r.Uint16Ptr("OPENVPN_ENDPOINT_PORT",
|
o.CustomPort, err = r.Uint16Ptr("OPENVPN_ENDPOINT_PORT",
|
||||||
reader.RetroKeys("PORT", "OPENVPN_PORT", "VPN_ENDPOINT_PORT"))
|
reader.RetroKeys("PORT", "OPENVPN_PORT", "VPN_ENDPOINT_PORT"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package settings
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||||
@@ -21,6 +22,12 @@ type ServerSelection struct {
|
|||||||
// or 'wireguard'. It cannot be the empty string
|
// or 'wireguard'. It cannot be the empty string
|
||||||
// in the internal state.
|
// in the internal state.
|
||||||
VPN string `json:"vpn"`
|
VPN string `json:"vpn"`
|
||||||
|
// TargetIP is the server endpoint IP address to use.
|
||||||
|
// It will override any IP address from the picked
|
||||||
|
// built-in server. It cannot be the empty value in the internal
|
||||||
|
// state, and can be set to the unspecified address to indicate
|
||||||
|
// there is not target IP address to use.
|
||||||
|
TargetIP netip.Addr `json:"target_ip"`
|
||||||
// Countries is the list of countries to filter VPN servers with.
|
// Countries is the list of countries to filter VPN servers with.
|
||||||
Countries []string `json:"countries"`
|
Countries []string `json:"countries"`
|
||||||
// Categories is the list of categories to filter VPN servers with.
|
// Categories is the list of categories to filter VPN servers with.
|
||||||
@@ -292,6 +299,7 @@ func validateFeatureFilters(settings ServerSelection, vpnServiceProvider string)
|
|||||||
func (ss *ServerSelection) copy() (copied ServerSelection) {
|
func (ss *ServerSelection) copy() (copied ServerSelection) {
|
||||||
return ServerSelection{
|
return ServerSelection{
|
||||||
VPN: ss.VPN,
|
VPN: ss.VPN,
|
||||||
|
TargetIP: ss.TargetIP,
|
||||||
Countries: gosettings.CopySlice(ss.Countries),
|
Countries: gosettings.CopySlice(ss.Countries),
|
||||||
Categories: gosettings.CopySlice(ss.Categories),
|
Categories: gosettings.CopySlice(ss.Categories),
|
||||||
Regions: gosettings.CopySlice(ss.Regions),
|
Regions: gosettings.CopySlice(ss.Regions),
|
||||||
@@ -315,6 +323,7 @@ func (ss *ServerSelection) copy() (copied ServerSelection) {
|
|||||||
|
|
||||||
func (ss *ServerSelection) overrideWith(other ServerSelection) {
|
func (ss *ServerSelection) overrideWith(other ServerSelection) {
|
||||||
ss.VPN = gosettings.OverrideWithComparable(ss.VPN, other.VPN)
|
ss.VPN = gosettings.OverrideWithComparable(ss.VPN, other.VPN)
|
||||||
|
ss.TargetIP = gosettings.OverrideWithValidator(ss.TargetIP, other.TargetIP)
|
||||||
ss.Countries = gosettings.OverrideWithSlice(ss.Countries, other.Countries)
|
ss.Countries = gosettings.OverrideWithSlice(ss.Countries, other.Countries)
|
||||||
ss.Categories = gosettings.OverrideWithSlice(ss.Categories, other.Categories)
|
ss.Categories = gosettings.OverrideWithSlice(ss.Categories, other.Categories)
|
||||||
ss.Regions = gosettings.OverrideWithSlice(ss.Regions, other.Regions)
|
ss.Regions = gosettings.OverrideWithSlice(ss.Regions, other.Regions)
|
||||||
@@ -337,6 +346,7 @@ func (ss *ServerSelection) overrideWith(other ServerSelection) {
|
|||||||
|
|
||||||
func (ss *ServerSelection) setDefaults(vpnProvider string, portForwardingEnabled bool) {
|
func (ss *ServerSelection) setDefaults(vpnProvider string, portForwardingEnabled bool) {
|
||||||
ss.VPN = gosettings.DefaultComparable(ss.VPN, vpn.OpenVPN)
|
ss.VPN = gosettings.DefaultComparable(ss.VPN, vpn.OpenVPN)
|
||||||
|
ss.TargetIP = gosettings.DefaultValidator(ss.TargetIP, netip.IPv4Unspecified())
|
||||||
ss.OwnedOnly = gosettings.DefaultPointer(ss.OwnedOnly, false)
|
ss.OwnedOnly = gosettings.DefaultPointer(ss.OwnedOnly, false)
|
||||||
ss.FreeOnly = gosettings.DefaultPointer(ss.FreeOnly, false)
|
ss.FreeOnly = gosettings.DefaultPointer(ss.FreeOnly, false)
|
||||||
ss.PremiumOnly = gosettings.DefaultPointer(ss.PremiumOnly, false)
|
ss.PremiumOnly = gosettings.DefaultPointer(ss.PremiumOnly, false)
|
||||||
@@ -358,6 +368,9 @@ func (ss ServerSelection) String() string {
|
|||||||
func (ss ServerSelection) toLinesNode() (node *gotree.Node) {
|
func (ss ServerSelection) toLinesNode() (node *gotree.Node) {
|
||||||
node = gotree.New("Server selection settings:")
|
node = gotree.New("Server selection settings:")
|
||||||
node.Appendf("VPN type: %s", ss.VPN)
|
node.Appendf("VPN type: %s", ss.VPN)
|
||||||
|
if !ss.TargetIP.IsUnspecified() {
|
||||||
|
node.Appendf("Target IP address: %s", ss.TargetIP)
|
||||||
|
}
|
||||||
|
|
||||||
if len(ss.Countries) > 0 {
|
if len(ss.Countries) > 0 {
|
||||||
node.Appendf("Countries: %s", strings.Join(ss.Countries, ", "))
|
node.Appendf("Countries: %s", strings.Join(ss.Countries, ", "))
|
||||||
@@ -448,6 +461,12 @@ func (ss *ServerSelection) read(r *reader.Reader,
|
|||||||
) (err error) {
|
) (err error) {
|
||||||
ss.VPN = vpnType
|
ss.VPN = vpnType
|
||||||
|
|
||||||
|
ss.TargetIP, err = r.NetipAddr("OPENVPN_ENDPOINT_IP",
|
||||||
|
reader.RetroKeys("OPENVPN_TARGET_IP", "VPN_ENDPOINT_IP"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
countriesRetroKeys := []string{"COUNTRY"}
|
countriesRetroKeys := []string{"COUNTRY"}
|
||||||
if vpnProvider == providers.Cyberghost {
|
if vpnProvider == providers.Cyberghost {
|
||||||
countriesRetroKeys = append(countriesRetroKeys, "REGION")
|
countriesRetroKeys = append(countriesRetroKeys, "REGION")
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package settings
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
"github.com/qdm12/gluetun/internal/configuration/settings/helpers"
|
||||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||||
@@ -174,13 +173,11 @@ func (s Settings) Warnings() (warnings []string) {
|
|||||||
"by creating an issue, attaching the new certificate and we will update Gluetun.")
|
"by creating an issue, attaching the new certificate and we will update Gluetun.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO remove in v4
|
for _, upstreamAddress := range s.DNS.UpstreamPlainAddresses {
|
||||||
if s.DNS.ServerAddress.Unmap().Compare(netip.AddrFrom4([4]byte{127, 0, 0, 1})) != 0 {
|
if upstreamAddress.Addr().IsPrivate() {
|
||||||
warnings = append(warnings, "DNS address is set to "+s.DNS.ServerAddress.String()+
|
warnings = append(warnings, "DNS upstream address "+upstreamAddress.String()+" is private: "+
|
||||||
" so the local forwarding DNS server will not be used."+
|
"DNS traffic might leak out of the VPN tunnel to that address.")
|
||||||
" The default value changed to 127.0.0.1 so it uses the internal DNS server."+
|
}
|
||||||
" If this server fails to start, the IPv4 address of the first plaintext DNS server"+
|
|
||||||
" corresponding to the first DNS provider chosen is used.")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return warnings
|
return warnings
|
||||||
|
|||||||
@@ -38,9 +38,6 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| ├── Run OpenVPN as: root
|
| ├── Run OpenVPN as: root
|
||||||
| └── Verbosity level: 1
|
| └── Verbosity level: 1
|
||||||
├── DNS settings:
|
├── DNS settings:
|
||||||
| ├── Keep existing nameserver(s): no
|
|
||||||
| ├── DNS server address to use: 127.0.0.1
|
|
||||||
| ├── DNS forwarder server enabled: yes
|
|
||||||
| ├── Upstream resolver type: dot
|
| ├── Upstream resolver type: dot
|
||||||
| ├── Upstream resolvers:
|
| ├── Upstream resolvers:
|
||||||
| | └── Cloudflare
|
| | └── Cloudflare
|
||||||
@@ -57,13 +54,8 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| └── Log level: INFO
|
| └── Log level: INFO
|
||||||
├── Health settings:
|
├── Health settings:
|
||||||
| ├── Server listening address: 127.0.0.1:9999
|
| ├── Server listening address: 127.0.0.1:9999
|
||||||
| ├── Target addresses:
|
| ├── Target address: cloudflare.com:443
|
||||||
| | ├── cloudflare.com:443
|
| ├── ICMP target IP: VPN server IP
|
||||||
| | └── github.com:443
|
|
||||||
| ├── Small health check type: ICMP echo request
|
|
||||||
| | └── ICMP target IPs:
|
|
||||||
| | ├── 1.1.1.1
|
|
||||||
| | └── 8.8.8.8
|
|
||||||
| └── Restart VPN on healthcheck failure: yes
|
| └── Restart VPN on healthcheck failure: yes
|
||||||
├── Shadowsocks server settings:
|
├── Shadowsocks server settings:
|
||||||
| └── Enabled: no
|
| └── Enabled: no
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ type Updater struct {
|
|||||||
// Providers is the list of VPN service providers
|
// Providers is the list of VPN service providers
|
||||||
// to update server information for.
|
// to update server information for.
|
||||||
Providers []string
|
Providers []string
|
||||||
// ProtonEmail is the email to authenticate with the Proton API.
|
// ProtonUsername is the username to authenticate with the Proton API.
|
||||||
ProtonEmail *string
|
ProtonUsername *string
|
||||||
// ProtonPassword is the password to authenticate with the Proton API.
|
// ProtonPassword is the password to authenticate with the Proton API.
|
||||||
ProtonPassword *string
|
ProtonPassword *string
|
||||||
}
|
}
|
||||||
@@ -58,11 +58,11 @@ func (u Updater) Validate() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if provider == providers.Protonvpn {
|
if provider == providers.Protonvpn {
|
||||||
authenticatedAPI := *u.ProtonEmail != "" || *u.ProtonPassword != ""
|
authenticatedAPI := *u.ProtonUsername != "" || *u.ProtonPassword != ""
|
||||||
if authenticatedAPI {
|
if authenticatedAPI {
|
||||||
switch {
|
switch {
|
||||||
case *u.ProtonEmail == "":
|
case *u.ProtonUsername == "":
|
||||||
return fmt.Errorf("%w", ErrUpdaterProtonEmailMissing)
|
return fmt.Errorf("%w", ErrUpdaterProtonUsernameMissing)
|
||||||
case *u.ProtonPassword == "":
|
case *u.ProtonPassword == "":
|
||||||
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
|
return fmt.Errorf("%w", ErrUpdaterProtonPasswordMissing)
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func (u *Updater) copy() (copied Updater) {
|
|||||||
DNSAddress: u.DNSAddress,
|
DNSAddress: u.DNSAddress,
|
||||||
MinRatio: u.MinRatio,
|
MinRatio: u.MinRatio,
|
||||||
Providers: gosettings.CopySlice(u.Providers),
|
Providers: gosettings.CopySlice(u.Providers),
|
||||||
ProtonEmail: gosettings.CopyPointer(u.ProtonEmail),
|
ProtonUsername: gosettings.CopyPointer(u.ProtonUsername),
|
||||||
ProtonPassword: gosettings.CopyPointer(u.ProtonPassword),
|
ProtonPassword: gosettings.CopyPointer(u.ProtonPassword),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,7 @@ func (u *Updater) overrideWith(other Updater) {
|
|||||||
u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress)
|
u.DNSAddress = gosettings.OverrideWithComparable(u.DNSAddress, other.DNSAddress)
|
||||||
u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio)
|
u.MinRatio = gosettings.OverrideWithComparable(u.MinRatio, other.MinRatio)
|
||||||
u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers)
|
u.Providers = gosettings.OverrideWithSlice(u.Providers, other.Providers)
|
||||||
u.ProtonEmail = gosettings.OverrideWithPointer(u.ProtonEmail, other.ProtonEmail)
|
u.ProtonUsername = gosettings.OverrideWithPointer(u.ProtonUsername, other.ProtonUsername)
|
||||||
u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword)
|
u.ProtonPassword = gosettings.OverrideWithPointer(u.ProtonPassword, other.ProtonPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ func (u *Updater) SetDefaults(vpnProvider string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set these to empty strings to avoid nil pointer panics
|
// Set these to empty strings to avoid nil pointer panics
|
||||||
u.ProtonEmail = gosettings.DefaultPointer(u.ProtonEmail, "")
|
u.ProtonUsername = gosettings.DefaultPointer(u.ProtonUsername, "")
|
||||||
u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "")
|
u.ProtonPassword = gosettings.DefaultPointer(u.ProtonPassword, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ func (u Updater) toLinesNode() (node *gotree.Node) {
|
|||||||
node.Appendf("Minimum ratio: %.1f", u.MinRatio)
|
node.Appendf("Minimum ratio: %.1f", u.MinRatio)
|
||||||
node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", "))
|
node.Appendf("Providers to update: %s", strings.Join(u.Providers, ", "))
|
||||||
if slices.Contains(u.Providers, providers.Protonvpn) {
|
if slices.Contains(u.Providers, providers.Protonvpn) {
|
||||||
node.Appendf("Proton API email: %s", *u.ProtonEmail)
|
node.Appendf("Proton API username: %s", *u.ProtonUsername)
|
||||||
node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword))
|
node.Appendf("Proton API password: %s", gosettings.ObfuscateKey(*u.ProtonPassword))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,13 +154,11 @@ func (u *Updater) read(r *reader.Reader) (err error) {
|
|||||||
|
|
||||||
u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS")
|
u.Providers = r.CSV("UPDATER_VPN_SERVICE_PROVIDERS")
|
||||||
|
|
||||||
u.ProtonEmail = r.Get("UPDATER_PROTONVPN_EMAIL")
|
u.ProtonUsername = r.Get("UPDATER_PROTONVPN_USERNAME")
|
||||||
if u.ProtonEmail == nil {
|
if u.ProtonUsername != nil {
|
||||||
protonUsername := r.String("UPDATER_PROTONVPN_USERNAME", reader.IsRetro("UPDATER_PROTONVPN_EMAIL"))
|
// Enforce to use the username not the email address
|
||||||
if protonUsername != "" {
|
*u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@protonmail.com")
|
||||||
protonEmail := protonUsername + "@protonmail.com"
|
*u.ProtonUsername = strings.TrimSuffix(*u.ProtonUsername, "@proton.me")
|
||||||
u.ProtonEmail = &protonEmail
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
u.ProtonPassword = r.Get("UPDATER_PROTONVPN_PASSWORD")
|
u.ProtonPassword = r.Get("UPDATER_PROTONVPN_PASSWORD")
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ import (
|
|||||||
|
|
||||||
type WireguardSelection struct {
|
type WireguardSelection struct {
|
||||||
// EndpointIP is the server endpoint IP address.
|
// EndpointIP is the server endpoint IP address.
|
||||||
// It is notably required with the custom provider.
|
// It is only used with VPN providers generating Wireguard
|
||||||
// Otherwise it overrides any IP address from the picked
|
// configurations specific to each server and user.
|
||||||
// built-in server connection. To indicate it should
|
// To indicate it should not be used, it should be set
|
||||||
// not be used, it should be set to [netip.IPv4Unspecified].
|
// to netip.IPv4Unspecified(). It can never be the zero value
|
||||||
// It can never be the zero value in the internal state.
|
// in the internal state.
|
||||||
EndpointIP netip.Addr `json:"endpoint_ip"`
|
EndpointIP netip.Addr `json:"endpoint_ip"`
|
||||||
// EndpointPort is a the server port to use for the VPN server.
|
// EndpointPort is a the server port to use for the VPN server.
|
||||||
// It is optional for VPN providers IVPN, Mullvad, Surfshark
|
// It is optional for VPN providers IVPN, Mullvad, Surfshark
|
||||||
|
|||||||
@@ -18,14 +18,8 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *l.GetSettings().KeepNameserver {
|
const fallback = false
|
||||||
l.logger.Warn("⚠️⚠️⚠️ keeping the default container nameservers, " +
|
l.useUnencryptedDNS(fallback)
|
||||||
"this will likely leak DNS traffic outside the VPN " +
|
|
||||||
"and go through your container network DNS outside the VPN tunnel!")
|
|
||||||
} else {
|
|
||||||
const fallback = false
|
|
||||||
l.useUnencryptedDNS(fallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-l.start:
|
case <-l.start:
|
||||||
@@ -38,8 +32,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
// Their values are to be used if DOT=off
|
// Their values are to be used if DOT=off
|
||||||
var runError <-chan error
|
var runError <-chan error
|
||||||
|
|
||||||
settings := l.GetSettings()
|
for {
|
||||||
for !*settings.KeepNameserver && *settings.ServerEnabled {
|
|
||||||
var err error
|
var err error
|
||||||
runError, err = l.setupServer(ctx)
|
runError, err = l.setupServer(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -59,15 +52,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
l.useUnencryptedDNS(fallback)
|
l.useUnencryptedDNS(fallback)
|
||||||
}
|
}
|
||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
settings = l.GetSettings()
|
|
||||||
}
|
}
|
||||||
l.signalOrSetStatus(constants.Running)
|
l.signalOrSetStatus(constants.Running)
|
||||||
|
|
||||||
settings = l.GetSettings()
|
const fallback = false
|
||||||
if !*settings.KeepNameserver && !*settings.ServerEnabled {
|
l.useUnencryptedDNS(fallback)
|
||||||
const fallback = false
|
|
||||||
l.useUnencryptedDNS(fallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
l.userTrigger = false
|
l.userTrigger = false
|
||||||
|
|
||||||
@@ -82,19 +71,15 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
if !*l.GetSettings().KeepNameserver {
|
l.stopServer()
|
||||||
l.stopServer()
|
// TODO revert OS and Go nameserver when exiting
|
||||||
// TODO revert OS and Go nameserver when exiting
|
|
||||||
}
|
|
||||||
return true
|
return true
|
||||||
case <-l.stop:
|
case <-l.stop:
|
||||||
l.userTrigger = true
|
l.userTrigger = true
|
||||||
l.logger.Info("stopping")
|
l.logger.Info("stopping")
|
||||||
if !*l.GetSettings().KeepNameserver {
|
const fallback = false
|
||||||
const fallback = false
|
l.useUnencryptedDNS(fallback)
|
||||||
l.useUnencryptedDNS(fallback)
|
l.stopServer()
|
||||||
l.stopServer()
|
|
||||||
}
|
|
||||||
l.stopped <- struct{}{}
|
l.stopped <- struct{}{}
|
||||||
case <-l.start:
|
case <-l.start:
|
||||||
l.userTrigger = true
|
l.userTrigger = true
|
||||||
|
|||||||
@@ -26,31 +26,23 @@ func (l *Loop) SetSettings(ctx context.Context, settings settings.DNS) (
|
|||||||
return l.state.SetSettings(ctx, settings)
|
return l.state.SetSettings(ctx, settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildServerSettings(settings settings.DNS,
|
func buildServerSettings(userSettings settings.DNS,
|
||||||
filter *mapfilter.Filter, localResolvers []netip.Addr,
|
filter *mapfilter.Filter, localResolvers []netip.Addr,
|
||||||
logger Logger) (
|
logger Logger) (
|
||||||
serverSettings server.Settings, err error,
|
serverSettings server.Settings, err error,
|
||||||
) {
|
) {
|
||||||
serverSettings.Logger = logger
|
serverSettings.Logger = logger
|
||||||
|
|
||||||
providersData := provider.NewProviders()
|
upstreamResolvers := buildProviders(userSettings)
|
||||||
upstreamResolvers := make([]provider.Provider, len(settings.Providers))
|
|
||||||
for i := range settings.Providers {
|
|
||||||
var err error
|
|
||||||
upstreamResolvers[i], err = providersData.Get(settings.Providers[i])
|
|
||||||
if err != nil {
|
|
||||||
panic(err) // this should already had been checked
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ipVersion := "ipv4"
|
ipVersion := "ipv4"
|
||||||
if *settings.IPv6 {
|
if *userSettings.IPv6 {
|
||||||
ipVersion = "ipv6"
|
ipVersion = "ipv6"
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialer server.Dialer
|
var dialer server.Dialer
|
||||||
switch settings.UpstreamType {
|
switch userSettings.UpstreamType {
|
||||||
case "dot":
|
case settings.DNSUpstreamTypeDot:
|
||||||
dialerSettings := dot.Settings{
|
dialerSettings := dot.Settings{
|
||||||
UpstreamResolvers: upstreamResolvers,
|
UpstreamResolvers: upstreamResolvers,
|
||||||
IPVersion: ipVersion,
|
IPVersion: ipVersion,
|
||||||
@@ -59,7 +51,7 @@ func buildServerSettings(settings settings.DNS,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return server.Settings{}, fmt.Errorf("creating DNS over TLS dialer: %w", err)
|
return server.Settings{}, fmt.Errorf("creating DNS over TLS dialer: %w", err)
|
||||||
}
|
}
|
||||||
case "doh":
|
case settings.DNSUpstreamTypeDoh:
|
||||||
dialerSettings := doh.Settings{
|
dialerSettings := doh.Settings{
|
||||||
UpstreamResolvers: upstreamResolvers,
|
UpstreamResolvers: upstreamResolvers,
|
||||||
IPVersion: ipVersion,
|
IPVersion: ipVersion,
|
||||||
@@ -68,7 +60,7 @@ func buildServerSettings(settings settings.DNS,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return server.Settings{}, fmt.Errorf("creating DNS over HTTPS dialer: %w", err)
|
return server.Settings{}, fmt.Errorf("creating DNS over HTTPS dialer: %w", err)
|
||||||
}
|
}
|
||||||
case "plain":
|
case settings.DNSUpstreamTypePlain:
|
||||||
dialerSettings := plain.Settings{
|
dialerSettings := plain.Settings{
|
||||||
UpstreamResolvers: upstreamResolvers,
|
UpstreamResolvers: upstreamResolvers,
|
||||||
IPVersion: ipVersion,
|
IPVersion: ipVersion,
|
||||||
@@ -78,11 +70,11 @@ func buildServerSettings(settings settings.DNS,
|
|||||||
return server.Settings{}, fmt.Errorf("creating plain DNS dialer: %w", err)
|
return server.Settings{}, fmt.Errorf("creating plain DNS dialer: %w", err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
panic("unknown upstream type: " + settings.UpstreamType)
|
panic("unknown upstream type: " + userSettings.UpstreamType)
|
||||||
}
|
}
|
||||||
serverSettings.Dialer = dialer
|
serverSettings.Dialer = dialer
|
||||||
|
|
||||||
if *settings.Caching {
|
if *userSettings.Caching {
|
||||||
lruCache, err := lru.New(lru.Settings{})
|
lruCache, err := lru.New(lru.Settings{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return server.Settings{}, fmt.Errorf("creating LRU cache: %w", err)
|
return server.Settings{}, fmt.Errorf("creating LRU cache: %w", err)
|
||||||
@@ -123,3 +115,48 @@ func buildServerSettings(settings settings.DNS,
|
|||||||
|
|
||||||
return serverSettings, nil
|
return serverSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildProviders(userSettings settings.DNS) []provider.Provider {
|
||||||
|
if userSettings.UpstreamType == settings.DNSUpstreamTypePlain &&
|
||||||
|
len(userSettings.UpstreamPlainAddresses) > 0 {
|
||||||
|
providers := make([]provider.Provider, len(userSettings.UpstreamPlainAddresses))
|
||||||
|
for i, addrPort := range userSettings.UpstreamPlainAddresses {
|
||||||
|
providers[i] = provider.Provider{
|
||||||
|
Name: addrPort.String(),
|
||||||
|
}
|
||||||
|
if addrPort.Addr().Is4() {
|
||||||
|
providers[i].Plain.IPv4 = []netip.AddrPort{addrPort}
|
||||||
|
} else {
|
||||||
|
providers[i].Plain.IPv6 = []netip.AddrPort{addrPort}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
providersData := provider.NewProviders()
|
||||||
|
providers := make([]provider.Provider, 0, len(userSettings.Providers)+len(userSettings.UpstreamPlainAddresses))
|
||||||
|
for _, providerName := range userSettings.Providers {
|
||||||
|
provider, err := providersData.Get(providerName)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // this should already had been checked
|
||||||
|
}
|
||||||
|
providers = append(providers, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userSettings.UpstreamType != settings.DNSUpstreamTypePlain {
|
||||||
|
return providers
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addrPort := range userSettings.UpstreamPlainAddresses {
|
||||||
|
newProvider := provider.Provider{
|
||||||
|
Name: addrPort.String(),
|
||||||
|
}
|
||||||
|
if addrPort.Addr().Is4() {
|
||||||
|
newProvider.Plain.IPv4 = []netip.AddrPort{addrPort}
|
||||||
|
} else {
|
||||||
|
newProvider.Plain.IPv6 = []netip.AddrPort{addrPort}
|
||||||
|
}
|
||||||
|
providers = append(providers, newProvider)
|
||||||
|
}
|
||||||
|
|
||||||
|
return providers
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/check"
|
"github.com/qdm12/dns/v2/pkg/check"
|
||||||
"github.com/qdm12/dns/v2/pkg/nameserver"
|
"github.com/qdm12/dns/v2/pkg/nameserver"
|
||||||
@@ -38,12 +37,8 @@ func (l *Loop) setupServer(ctx context.Context) (runError <-chan error, err erro
|
|||||||
l.server = server
|
l.server = server
|
||||||
|
|
||||||
// use internal DNS server
|
// use internal DNS server
|
||||||
const defaultDNSPort = 53
|
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{})
|
||||||
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
|
|
||||||
AddrPort: netip.AddrPortFrom(settings.ServerAddress, defaultDNSPort),
|
|
||||||
})
|
|
||||||
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
|
||||||
IPs: []netip.Addr{settings.ServerAddress},
|
|
||||||
ResolvPath: l.resolvConf,
|
ResolvPath: l.resolvConf,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -40,8 +40,6 @@ func (s *State) SetSettings(ctx context.Context, settings settings.DNS) (
|
|||||||
|
|
||||||
// Restart
|
// Restart
|
||||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
|
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
|
||||||
if *settings.ServerEnabled {
|
outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
|
||||||
outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
|
|
||||||
}
|
|
||||||
return outcome
|
return outcome
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,16 +16,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Checker struct {
|
type Checker struct {
|
||||||
tlsDialAddrs []string
|
tlsDialAddr string
|
||||||
dialer *net.Dialer
|
dialer *net.Dialer
|
||||||
echoer *icmp.Echoer
|
echoer *icmp.Echoer
|
||||||
dnsClient *dns.Client
|
dnsClient *dns.Client
|
||||||
logger Logger
|
logger Logger
|
||||||
icmpTargetIPs []netip.Addr
|
icmpTarget netip.Addr
|
||||||
smallCheckType string
|
configMutex sync.Mutex
|
||||||
configMutex sync.Mutex
|
|
||||||
|
|
||||||
icmpNotPermitted bool
|
icmpNotPermitted bool
|
||||||
|
smallCheckName string
|
||||||
|
|
||||||
// Internal periodic service signals
|
// Internal periodic service signals
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
@@ -45,37 +45,35 @@ func NewChecker(logger Logger) *Checker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConfig sets the TCP+TLS dial addresses, the ICMP echo IP address
|
// SetConfig sets the TCP+TLS dial address and the ICMP echo IP address
|
||||||
// to target and the desired small check type (dns or icmp).
|
// to target by the [Checker].
|
||||||
// This function MUST be called before calling [Checker.Start].
|
// This function MUST be called before calling [Checker.Start].
|
||||||
func (c *Checker) SetConfig(tlsDialAddrs []string, icmpTargets []netip.Addr,
|
func (c *Checker) SetConfig(tlsDialAddr string, icmpTarget netip.Addr) {
|
||||||
smallCheckType string,
|
|
||||||
) {
|
|
||||||
c.configMutex.Lock()
|
c.configMutex.Lock()
|
||||||
defer c.configMutex.Unlock()
|
defer c.configMutex.Unlock()
|
||||||
c.tlsDialAddrs = tlsDialAddrs
|
c.tlsDialAddr = tlsDialAddr
|
||||||
c.icmpTargetIPs = icmpTargets
|
c.icmpTarget = icmpTarget
|
||||||
c.smallCheckType = smallCheckType
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the checker by first running a blocking 6s-timed TCP+TLS check,
|
// Start starts the checker by first running a blocking 2s-timed TCP+TLS check,
|
||||||
// and, on success, starts the periodic checks in a separate goroutine:
|
// and, on success, starts the periodic checks in a separate goroutine:
|
||||||
// - a "small" ICMP echo check every minute
|
// - a "small" ICMP echo check every 15 seconds
|
||||||
// - a "full" TCP+TLS check every 5 minutes
|
// - a "full" TCP+TLS check every 5 minutes
|
||||||
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
|
// It returns a channel `runError` that receives an error (nil or not) when a periodic check is performed.
|
||||||
// It returns an error if the initial TCP+TLS check fails.
|
// It returns an error if the initial TCP+TLS check fails.
|
||||||
// The Checker has to be ultimately stopped by calling [Checker.Stop].
|
// The Checker has to be ultimately stopped by calling [Checker.Stop].
|
||||||
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
|
func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) {
|
||||||
if len(c.tlsDialAddrs) == 0 || len(c.icmpTargetIPs) == 0 || c.smallCheckType == "" {
|
if c.tlsDialAddr == "" || c.icmpTarget.IsUnspecified() {
|
||||||
panic("call Checker.SetConfig with non empty values before Checker.Start")
|
panic("call Checker.SetConfig with non empty values before Checker.Start")
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.icmpNotPermitted {
|
// connection isn't under load yet when the checker starts, so a short
|
||||||
// restore forced check type to dns if icmp was found to be not permitted
|
// 6 seconds timeout suffices and provides quick enough feedback that
|
||||||
c.smallCheckType = smallCheckDNS
|
// the new connection is not working.
|
||||||
}
|
const timeout = 6 * time.Second
|
||||||
|
tcpTLSCheckCtx, tcpTLSCheckCancel := context.WithTimeout(ctx, timeout)
|
||||||
err = c.startupCheck(ctx)
|
err = tcpTLSCheck(tcpTLSCheckCtx, c.dialer, c.tlsDialAddr)
|
||||||
|
tcpTLSCheckCancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("startup check: %w", err)
|
return nil, fmt.Errorf("startup check: %w", err)
|
||||||
}
|
}
|
||||||
@@ -85,6 +83,7 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
|
|||||||
c.stop = cancel
|
c.stop = cancel
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
c.done = done
|
c.done = done
|
||||||
|
c.smallCheckName = "ICMP echo"
|
||||||
const smallCheckPeriod = time.Minute
|
const smallCheckPeriod = time.Minute
|
||||||
smallCheckTimer := time.NewTimer(smallCheckPeriod)
|
smallCheckTimer := time.NewTimer(smallCheckPeriod)
|
||||||
const fullCheckPeriod = 5 * time.Minute
|
const fullCheckPeriod = 5 * time.Minute
|
||||||
@@ -124,16 +123,13 @@ func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error)
|
|||||||
func (c *Checker) Stop() error {
|
func (c *Checker) Stop() error {
|
||||||
c.stop()
|
c.stop()
|
||||||
<-c.done
|
<-c.done
|
||||||
c.tlsDialAddrs = nil
|
c.icmpTarget = netip.Addr{}
|
||||||
c.icmpTargetIPs = nil
|
|
||||||
c.smallCheckType = ""
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
|
func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
|
||||||
c.configMutex.Lock()
|
c.configMutex.Lock()
|
||||||
icmpTargetIPs := make([]netip.Addr, len(c.icmpTargetIPs))
|
ip := c.icmpTarget
|
||||||
copy(icmpTargetIPs, c.icmpTargetIPs)
|
|
||||||
c.configMutex.Unlock()
|
c.configMutex.Unlock()
|
||||||
tryTimeouts := []time.Duration{
|
tryTimeouts := []time.Duration{
|
||||||
5 * time.Second,
|
5 * time.Second,
|
||||||
@@ -147,31 +143,28 @@ func (c *Checker) smallPeriodicCheck(ctx context.Context) error {
|
|||||||
15 * time.Second,
|
15 * time.Second,
|
||||||
30 * time.Second,
|
30 * time.Second,
|
||||||
}
|
}
|
||||||
check := func(ctx context.Context, try int) error {
|
check := func(ctx context.Context) error {
|
||||||
if c.smallCheckType == smallCheckDNS {
|
if c.icmpNotPermitted {
|
||||||
return c.dnsClient.Check(ctx)
|
return c.dnsClient.Check(ctx)
|
||||||
}
|
}
|
||||||
ip := icmpTargetIPs[try%len(icmpTargetIPs)]
|
|
||||||
err := c.echoer.Echo(ctx, ip)
|
err := c.echoer.Echo(ctx, ip)
|
||||||
if errors.Is(err, icmp.ErrNotPermitted) {
|
if errors.Is(err, icmp.ErrNotPermitted) {
|
||||||
c.icmpNotPermitted = true
|
c.icmpNotPermitted = true
|
||||||
c.smallCheckType = smallCheckDNS
|
c.smallCheckName = "plain DNS over UDP"
|
||||||
c.logger.Infof("%s; permanently falling back to %s checks",
|
c.logger.Infof("%s; permanently falling back to %s checks.", c.smallCheckName, err)
|
||||||
smallCheckTypeToString(c.smallCheckType), err)
|
|
||||||
return c.dnsClient.Check(ctx)
|
return c.dnsClient.Check(ctx)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return withRetries(ctx, tryTimeouts, c.logger, smallCheckTypeToString(c.smallCheckType), check)
|
return withRetries(ctx, tryTimeouts, c.logger, c.smallCheckName, check)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Checker) fullPeriodicCheck(ctx context.Context) error {
|
func (c *Checker) fullPeriodicCheck(ctx context.Context) error {
|
||||||
// 20s timeout in case the connection is under stress
|
// 20s timeout in case the connection is under stress
|
||||||
// See https://github.com/qdm12/gluetun/issues/2270
|
// See https://github.com/qdm12/gluetun/issues/2270
|
||||||
tryTimeouts := []time.Duration{10 * time.Second, 15 * time.Second, 30 * time.Second}
|
tryTimeouts := []time.Duration{10 * time.Second, 15 * time.Second, 30 * time.Second}
|
||||||
check := func(ctx context.Context, try int) error {
|
check := func(ctx context.Context) error {
|
||||||
tlsDialAddr := c.tlsDialAddrs[try%len(c.tlsDialAddrs)]
|
return tcpTLSCheck(ctx, c.dialer, c.tlsDialAddr)
|
||||||
return tcpTLSCheck(ctx, c.dialer, tlsDialAddr)
|
|
||||||
}
|
}
|
||||||
return withRetries(ctx, tryTimeouts, c.logger, "TCP+TLS dial", check)
|
return withRetries(ctx, tryTimeouts, c.logger, "TCP+TLS dial", check)
|
||||||
}
|
}
|
||||||
@@ -233,18 +226,18 @@ func makeAddressToDial(address string) (addressToDial string, err error) {
|
|||||||
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
|
var ErrAllCheckTriesFailed = errors.New("all check tries failed")
|
||||||
|
|
||||||
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
||||||
logger Logger, checkName string, check func(ctx context.Context, try int) error,
|
logger Logger, checkName string, check func(ctx context.Context) error,
|
||||||
) error {
|
) error {
|
||||||
maxTries := len(tryTimeouts)
|
maxTries := len(tryTimeouts)
|
||||||
type errData struct {
|
type errData struct {
|
||||||
err error
|
err error
|
||||||
durationMS int64
|
duration time.Duration
|
||||||
}
|
}
|
||||||
errs := make([]errData, maxTries)
|
errs := make([]errData, maxTries)
|
||||||
for i, timeout := range tryTimeouts {
|
for i, timeout := range tryTimeouts {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
checkCtx, cancel := context.WithTimeout(ctx, timeout)
|
checkCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
err := check(checkCtx, i)
|
err := check(checkCtx)
|
||||||
cancel()
|
cancel()
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
@@ -254,73 +247,12 @@ func withRetries(ctx context.Context, tryTimeouts []time.Duration,
|
|||||||
}
|
}
|
||||||
logger.Debugf("%s attempt %d/%d failed: %s", checkName, i+1, maxTries, err)
|
logger.Debugf("%s attempt %d/%d failed: %s", checkName, i+1, maxTries, err)
|
||||||
errs[i].err = err
|
errs[i].err = err
|
||||||
errs[i].durationMS = time.Since(start).Round(time.Millisecond).Milliseconds()
|
errs[i].duration = time.Since(start)
|
||||||
}
|
}
|
||||||
|
|
||||||
errStrings := make([]string, len(errs))
|
errStrings := make([]string, len(errs))
|
||||||
for i, err := range errs {
|
for i, err := range errs {
|
||||||
errStrings[i] = fmt.Sprintf("attempt %d (%dms): %s", i+1, err.durationMS, err.err)
|
errStrings[i] = fmt.Sprintf("attempt %d (%s): %s", i+1, err.duration, err.err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
|
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Checker) startupCheck(ctx context.Context) error {
|
|
||||||
// connection isn't under load yet when the checker starts, so a short
|
|
||||||
// 6 seconds timeout suffices and provides quick enough feedback that
|
|
||||||
// the new connection is not working. However, since the addresses to dial
|
|
||||||
// may be multiple, we run the check in parallel. If any succeeds, the check passes.
|
|
||||||
// This is to prevent false negatives at startup, if one of the addresses is down
|
|
||||||
// for external reasons.
|
|
||||||
const timeout = 6 * time.Second
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
||||||
defer cancel()
|
|
||||||
errCh := make(chan error)
|
|
||||||
|
|
||||||
for _, address := range c.tlsDialAddrs {
|
|
||||||
go func(addr string) {
|
|
||||||
err := tcpTLSCheck(ctx, c.dialer, addr)
|
|
||||||
errCh <- err
|
|
||||||
}(address)
|
|
||||||
}
|
|
||||||
|
|
||||||
errs := make([]error, 0, len(c.tlsDialAddrs))
|
|
||||||
success := false
|
|
||||||
for range c.tlsDialAddrs {
|
|
||||||
err := <-errCh
|
|
||||||
if err == nil {
|
|
||||||
success = true
|
|
||||||
cancel()
|
|
||||||
continue
|
|
||||||
} else if success {
|
|
||||||
continue // ignore canceled errors after success
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Debugf("startup check parallel attempt failed: %s", err)
|
|
||||||
errs = append(errs, err)
|
|
||||||
}
|
|
||||||
if success {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
errStrings := make([]string, len(errs))
|
|
||||||
for i, err := range errs {
|
|
||||||
errStrings[i] = fmt.Sprintf("parallel attempt %d/%d failed: %s", i+1, len(errs), err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("%w: %s", ErrAllCheckTriesFailed, strings.Join(errStrings, ", "))
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
smallCheckDNS = "dns"
|
|
||||||
smallCheckICMP = "icmp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func smallCheckTypeToString(smallCheckType string) string {
|
|
||||||
switch smallCheckType {
|
|
||||||
case smallCheckICMP:
|
|
||||||
return "ICMP echo"
|
|
||||||
case smallCheckDNS:
|
|
||||||
return "plain DNS over UDP"
|
|
||||||
default:
|
|
||||||
panic("unknown small check type: " + smallCheckType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ func Test_Checker_fullcheck(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
dialer := &net.Dialer{}
|
dialer := &net.Dialer{}
|
||||||
addresses := []string{"badaddress:9876", "cloudflare.com:443", "google.com:443"}
|
const address = "cloudflare.com:443"
|
||||||
|
|
||||||
checker := &Checker{
|
checker := &Checker{
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
tlsDialAddrs: addresses,
|
tlsDialAddr: address,
|
||||||
}
|
}
|
||||||
|
|
||||||
canceledCtx, cancel := context.WithCancel(context.Background())
|
canceledCtx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -52,8 +52,8 @@ func Test_Checker_fullcheck(t *testing.T) {
|
|||||||
|
|
||||||
dialer := &net.Dialer{}
|
dialer := &net.Dialer{}
|
||||||
checker := &Checker{
|
checker := &Checker{
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
tlsDialAddrs: []string{listeningAddress.String()},
|
tlsDialAddr: listeningAddress.String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checker.fullPeriodicCheck(ctx)
|
err = checker.fullPeriodicCheck(ctx)
|
||||||
|
|||||||
@@ -56,10 +56,10 @@ func (c *Client) Check(ctx context.Context) error {
|
|||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
||||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, err)
|
return err
|
||||||
case len(ips) == 0:
|
case len(ips) == 0:
|
||||||
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
c.dnsIPIndex = (c.dnsIPIndex + 1) % len(c.serverAddrs)
|
||||||
return fmt.Errorf("with DNS server %s: %w", dnsAddr, ErrLookupNoIPs)
|
return fmt.Errorf("%w", ErrLookupNoIPs)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,20 +82,20 @@ func (i *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) {
|
|||||||
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
if strings.HasSuffix(err.Error(), "sendto: operation not permitted") {
|
||||||
err = fmt.Errorf("%w", ErrNotPermitted)
|
err = fmt.Errorf("%w", ErrNotPermitted)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("writing ICMP message to %s: %w", ip, err)
|
return fmt.Errorf("writing ICMP message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
receivedData, err := receiveEchoReply(conn, id, i.buffer, ipVersion, i.logger)
|
receivedData, err := receiveEchoReply(conn, id, i.buffer, ipVersion, i.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
if errors.Is(err, net.ErrClosed) && ctx.Err() != nil {
|
||||||
return fmt.Errorf("%w from %s", ErrTimedOut, ip)
|
return fmt.Errorf("%w", ErrTimedOut)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("receiving ICMP echo reply from %s: %w", ip, err)
|
return fmt.Errorf("receiving ICMP echo reply: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert
|
||||||
if !bytes.Equal(receivedData, sentData) {
|
if !bytes.Equal(receivedData, sentData) {
|
||||||
return fmt.Errorf("%w: sent %x to %s and received %x", ErrICMPEchoDataMismatch, sentData, ip, receivedData)
|
return fmt.Errorf("%w: sent %x and received %x", ErrICMPEchoDataMismatch, sentData, receivedData)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
|
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
|
||||||
commandTemplate string, ports []uint16, vpnInterface string,
|
commandTemplate string, ports []uint16,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
portStrings := make([]string, len(ports))
|
portStrings := make([]string, len(ports))
|
||||||
for i, port := range ports {
|
for i, port := range ports {
|
||||||
@@ -19,7 +19,6 @@ func runCommand(ctx context.Context, cmder Cmder, logger Logger,
|
|||||||
portsString := strings.Join(portStrings, ",")
|
portsString := strings.Join(portStrings, ",")
|
||||||
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
|
commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString)
|
||||||
commandString = strings.ReplaceAll(commandString, "{{PORT}}", portStrings[0])
|
commandString = strings.ReplaceAll(commandString, "{{PORT}}", portStrings[0])
|
||||||
commandString = strings.ReplaceAll(commandString, "{{VPN_INTERFACE}}", vpnInterface)
|
|
||||||
args, err := command.Split(commandString)
|
args, err := command.Split(commandString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing command: %w", err)
|
return fmt.Errorf("parsing command: %w", err)
|
||||||
|
|||||||
@@ -17,13 +17,12 @@ func Test_Service_runCommand(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
cmder := command.New()
|
cmder := command.New()
|
||||||
const commandTemplate = `/bin/sh -c "echo {{PORTS}}-{{PORT}}-{{VPN_INTERFACE}}"`
|
const commandTemplate = `/bin/sh -c "echo {{PORTS}}"`
|
||||||
ports := []uint16{1234, 5678}
|
ports := []uint16{1234, 5678}
|
||||||
const vpnInterface = "tun0"
|
|
||||||
logger := NewMockLogger(ctrl)
|
logger := NewMockLogger(ctrl)
|
||||||
logger.EXPECT().Info("1234,5678-1234-tun0")
|
logger.EXPECT().Info("1234,5678")
|
||||||
|
|
||||||
err := runCommand(ctx, cmder, logger, commandTemplate, ports, vpnInterface)
|
err := runCommand(ctx, cmder, logger, commandTemplate, ports)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
|
|||||||
s.portMutex.Unlock()
|
s.portMutex.Unlock()
|
||||||
|
|
||||||
if s.settings.UpCommand != "" {
|
if s.settings.UpCommand != "" {
|
||||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports, s.settings.Interface)
|
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("running up command: %w", err)
|
err = fmt.Errorf("running up command: %w", err)
|
||||||
s.logger.Error(err.Error())
|
s.logger.Error(err.Error())
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func (s *Service) cleanup() (err error) {
|
|||||||
const downTimeout = 60 * time.Second
|
const downTimeout = 60 * time.Second
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports, s.settings.Interface)
|
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("running down command: %w", err)
|
err = fmt.Errorf("running down command: %w", err)
|
||||||
s.logger.Error(err.Error())
|
s.logger.Error(err.Error())
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ type Provider struct {
|
|||||||
|
|
||||||
func New(storage common.Storage, randSource rand.Source,
|
func New(storage common.Storage, randSource rand.Source,
|
||||||
client *http.Client, updaterWarner common.Warner,
|
client *http.Client, updaterWarner common.Warner,
|
||||||
email, password string,
|
username, password string,
|
||||||
) *Provider {
|
) *Provider {
|
||||||
return &Provider{
|
return &Provider{
|
||||||
storage: storage,
|
storage: storage,
|
||||||
randSource: randSource,
|
randSource: randSource,
|
||||||
Fetcher: updater.New(client, updaterWarner, email, password),
|
Fetcher: updater.New(client, updaterWarner, username, password),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (c *apiClient) setHeaders(request *http.Request, cookie cookie) {
|
|||||||
|
|
||||||
// authenticate performs the full Proton authentication flow
|
// authenticate performs the full Proton authentication flow
|
||||||
// to obtain an authenticated cookie (uid, token and session ID).
|
// to obtain an authenticated cookie (uid, token and session ID).
|
||||||
func (c *apiClient) authenticate(ctx context.Context, email, password string,
|
func (c *apiClient) authenticate(ctx context.Context, username, password string,
|
||||||
) (authCookie cookie, err error) {
|
) (authCookie cookie, err error) {
|
||||||
sessionID, err := c.getSessionID(ctx)
|
sessionID, err := c.getSessionID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -98,8 +98,8 @@ func (c *apiClient) authenticate(ctx context.Context, email, password string,
|
|||||||
token: cookieToken,
|
token: cookieToken,
|
||||||
sessionID: sessionID,
|
sessionID: sessionID,
|
||||||
}
|
}
|
||||||
username, modulusPGPClearSigned, serverEphemeralBase64, saltBase64,
|
modulusPGPClearSigned, serverEphemeralBase64, saltBase64,
|
||||||
srpSessionHex, version, err := c.authInfo(ctx, email, unauthCookie)
|
srpSessionHex, version, err := c.authInfo(ctx, username, unauthCookie)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cookie{}, fmt.Errorf("getting auth information: %w", err)
|
return cookie{}, fmt.Errorf("getting auth information: %w", err)
|
||||||
}
|
}
|
||||||
@@ -118,7 +118,7 @@ func (c *apiClient) authenticate(ctx context.Context, email, password string,
|
|||||||
return cookie{}, fmt.Errorf("generating SRP proofs: %w", err)
|
return cookie{}, fmt.Errorf("generating SRP proofs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authCookie, err = c.auth(ctx, unauthCookie, email, srpSessionHex, proofs)
|
authCookie, err = c.auth(ctx, unauthCookie, username, srpSessionHex, proofs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cookie{}, fmt.Errorf("authentifying: %w", err)
|
return cookie{}, fmt.Errorf("authentifying: %w", err)
|
||||||
}
|
}
|
||||||
@@ -299,45 +299,48 @@ func (c *apiClient) cookieToken(ctx context.Context, sessionID, tokenType, acces
|
|||||||
return "", fmt.Errorf("%w", ErrAuthCookieNotFound)
|
return "", fmt.Errorf("%w", ErrAuthCookieNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrUsernameDoesNotExist = errors.New("username does not exist")
|
var (
|
||||||
|
ErrUsernameDoesNotExist = errors.New("username does not exist")
|
||||||
|
ErrUsernameMismatch = errors.New("username in response does not match request username")
|
||||||
|
)
|
||||||
|
|
||||||
// authInfo fetches SRP parameters for the account.
|
// authInfo fetches SRP parameters for the account.
|
||||||
func (c *apiClient) authInfo(ctx context.Context, email string, unauthCookie cookie) (
|
func (c *apiClient) authInfo(ctx context.Context, username string, unauthCookie cookie) (
|
||||||
username, modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string,
|
modulusPGPClearSigned, serverEphemeralBase64, saltBase64, srpSessionHex string,
|
||||||
version int, err error,
|
version int, err error,
|
||||||
) {
|
) {
|
||||||
type requestBodySchema struct {
|
type requestBodySchema struct {
|
||||||
Intent string `json:"Intent"` // "Proton"
|
Intent string `json:"Intent"` // "Proton"
|
||||||
Username string `json:"Username"`
|
Username string `json:"Username"` // username without @domain.com
|
||||||
}
|
}
|
||||||
requestBody := requestBodySchema{
|
requestBody := requestBodySchema{
|
||||||
Intent: "Proton",
|
Intent: "Proton",
|
||||||
Username: email,
|
Username: username,
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer := bytes.NewBuffer(nil)
|
buffer := bytes.NewBuffer(nil)
|
||||||
encoder := json.NewEncoder(buffer)
|
encoder := json.NewEncoder(buffer)
|
||||||
if err := encoder.Encode(requestBody); err != nil {
|
if err := encoder.Encode(requestBody); err != nil {
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err)
|
return "", "", "", "", 0, fmt.Errorf("encoding request body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/info", buffer)
|
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURLBase+"/core/v4/auth/info", buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("creating request: %w", err)
|
return "", "", "", "", 0, fmt.Errorf("creating request: %w", err)
|
||||||
}
|
}
|
||||||
c.setHeaders(request, unauthCookie)
|
c.setHeaders(request, unauthCookie)
|
||||||
|
|
||||||
response, err := c.httpClient.Do(request)
|
response, err := c.httpClient.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", "", "", 0, err
|
return "", "", "", "", 0, err
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(response.Body)
|
responseBody, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("reading response body: %w", err)
|
return "", "", "", "", 0, fmt.Errorf("reading response body: %w", err)
|
||||||
} else if response.StatusCode != http.StatusOK {
|
} else if response.StatusCode != http.StatusOK {
|
||||||
return "", "", "", "", "", 0, buildError(response.StatusCode, responseBody)
|
return "", "", "", "", 0, buildError(response.StatusCode, responseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
var info struct {
|
var info struct {
|
||||||
@@ -351,30 +354,32 @@ func (c *apiClient) authInfo(ctx context.Context, email string, unauthCookie coo
|
|||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &info)
|
err = json.Unmarshal(responseBody, &info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err)
|
return "", "", "", "", 0, fmt.Errorf("decoding response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const successCode = 1000
|
const successCode = 1000
|
||||||
switch {
|
switch {
|
||||||
case info.Code != successCode:
|
case info.Code != successCode:
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d",
|
return "", "", "", "", 0, fmt.Errorf("%w: expected %d got %d",
|
||||||
ErrCodeNotSuccess, successCode, info.Code)
|
ErrCodeNotSuccess, successCode, info.Code)
|
||||||
case info.Modulus == "":
|
case info.Modulus == "":
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing)
|
return "", "", "", "", 0, fmt.Errorf("%w: modulus is empty", ErrDataFieldMissing)
|
||||||
case info.ServerEphemeral == "":
|
case info.ServerEphemeral == "":
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing)
|
return "", "", "", "", 0, fmt.Errorf("%w: server ephemeral is empty", ErrDataFieldMissing)
|
||||||
case info.Salt == "":
|
case info.Salt == "":
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist)
|
return "", "", "", "", 0, fmt.Errorf("%w (salt data field is empty)", ErrUsernameDoesNotExist)
|
||||||
case info.SRPSession == "":
|
case info.SRPSession == "":
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing)
|
return "", "", "", "", 0, fmt.Errorf("%w: SRP session is empty", ErrDataFieldMissing)
|
||||||
case info.Username == "":
|
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w: username is empty", ErrDataFieldMissing)
|
case !strings.EqualFold(info.Username, username):
|
||||||
|
return "", "", "", "", 0, fmt.Errorf("%w: expected %s got %s",
|
||||||
|
ErrUsernameMismatch, username, info.Username)
|
||||||
case info.Version == nil:
|
case info.Version == nil:
|
||||||
return "", "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing)
|
return "", "", "", "", 0, fmt.Errorf("%w: version is missing", ErrDataFieldMissing)
|
||||||
}
|
}
|
||||||
|
|
||||||
version = int(*info.Version) //nolint:gosec
|
version = int(*info.Version) //nolint:gosec
|
||||||
return info.Username, info.Modulus, info.ServerEphemeral, info.Salt,
|
return info.Modulus, info.ServerEphemeral, info.Salt,
|
||||||
info.SRPSession, version, nil
|
info.SRPSession, version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
|||||||
servers []models.Server, err error,
|
servers []models.Server, err error,
|
||||||
) {
|
) {
|
||||||
switch {
|
switch {
|
||||||
case u.email == "":
|
case u.username == "":
|
||||||
return nil, fmt.Errorf("%w: email is empty", common.ErrCredentialsMissing)
|
return nil, fmt.Errorf("%w: username is empty", common.ErrCredentialsMissing)
|
||||||
case u.password == "":
|
case u.password == "":
|
||||||
return nil, fmt.Errorf("%w: password is empty", common.ErrCredentialsMissing)
|
return nil, fmt.Errorf("%w: password is empty", common.ErrCredentialsMissing)
|
||||||
}
|
}
|
||||||
@@ -25,7 +25,7 @@ func (u *Updater) FetchServers(ctx context.Context, minServers int) (
|
|||||||
return nil, fmt.Errorf("creating API client: %w", err)
|
return nil, fmt.Errorf("creating API client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := apiClient.authenticate(ctx, u.email, u.password)
|
cookie, err := apiClient.authenticate(ctx, u.username, u.password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authentifying with Proton: %w", err)
|
return nil, fmt.Errorf("authentifying with Proton: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,15 +8,15 @@ import (
|
|||||||
|
|
||||||
type Updater struct {
|
type Updater struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
email string
|
username string
|
||||||
password string
|
password string
|
||||||
warner common.Warner
|
warner common.Warner
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(client *http.Client, warner common.Warner, email, password string) *Updater {
|
func New(client *http.Client, warner common.Warner, username, password string) *Updater {
|
||||||
return &Updater{
|
return &Updater{
|
||||||
client: client,
|
client: client,
|
||||||
email: email,
|
username: username,
|
||||||
password: password,
|
password: password,
|
||||||
warner: warner,
|
warner: warner,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func NewProviders(storage Storage, timeNow func() time.Time,
|
|||||||
providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
|
providers.Privado: privado.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
|
||||||
providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client),
|
providers.PrivateInternetAccess: privateinternetaccess.New(storage, randSource, timeNow, client),
|
||||||
providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
|
providers.Privatevpn: privatevpn.New(storage, randSource, unzipper, updaterWarner, parallelResolver),
|
||||||
providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonEmail, *credentials.ProtonPassword),
|
providers.Protonvpn: protonvpn.New(storage, randSource, client, updaterWarner, *credentials.ProtonUsername, *credentials.ProtonPassword),
|
||||||
providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
|
providers.Purevpn: purevpn.New(storage, randSource, ipFetcher, unzipper, updaterWarner, parallelResolver),
|
||||||
providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver),
|
providers.SlickVPN: slickvpn.New(storage, randSource, client, updaterWarner, parallelResolver),
|
||||||
providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),
|
providers.Surfshark: surfshark.New(storage, randSource, client, unzipper, updaterWarner, parallelResolver),
|
||||||
|
|||||||
@@ -26,25 +26,16 @@ func pickConnection(connections []models.Connection,
|
|||||||
return connection, ErrNoConnectionToPickFrom
|
return connection, ErrNoConnectionToPickFrom
|
||||||
}
|
}
|
||||||
|
|
||||||
var targetIP netip.Addr
|
targetIPSet := selection.TargetIP.IsValid() && !selection.TargetIP.IsUnspecified()
|
||||||
switch selection.VPN {
|
|
||||||
case vpn.OpenVPN:
|
|
||||||
targetIP = selection.OpenVPN.EndpointIP
|
|
||||||
case vpn.Wireguard:
|
|
||||||
targetIP = selection.Wireguard.EndpointIP
|
|
||||||
default:
|
|
||||||
panic("unknown VPN type: " + selection.VPN)
|
|
||||||
}
|
|
||||||
targetIPSet := targetIP.IsValid() && !targetIP.IsUnspecified()
|
|
||||||
|
|
||||||
if targetIPSet && selection.VPN == vpn.Wireguard {
|
if targetIPSet && selection.VPN == vpn.Wireguard {
|
||||||
// we need the right public key
|
// we need the right public key
|
||||||
return getTargetIPConnection(connections, targetIP)
|
return getTargetIPConnection(connections, selection.TargetIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
connection = pickRandomConnection(connections, randSource)
|
connection = pickRandomConnection(connections, randSource)
|
||||||
if targetIPSet {
|
if targetIPSet {
|
||||||
connection.IP = targetIP
|
connection.IP = selection.TargetIP
|
||||||
}
|
}
|
||||||
|
|
||||||
return connection, nil
|
return connection, nil
|
||||||
|
|||||||
@@ -21,9 +21,6 @@ func (s *Storage) FlushToFile(path string) error {
|
|||||||
// flushToFile flushes the merged servers data to the file
|
// flushToFile flushes the merged servers data to the file
|
||||||
// specified by path, as indented JSON. It is not thread-safe.
|
// specified by path, as indented JSON. It is not thread-safe.
|
||||||
func (s *Storage) flushToFile(path string) error {
|
func (s *Storage) flushToFile(path string) error {
|
||||||
if path == "" {
|
|
||||||
return nil // no file to write to
|
|
||||||
}
|
|
||||||
const permission = 0o644
|
const permission = 0o644
|
||||||
dirPath := filepath.Dir(path)
|
dirPath := filepath.Dir(path)
|
||||||
if err := os.MkdirAll(dirPath, permission); err != nil {
|
if err := os.MkdirAll(dirPath, permission); err != nil {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func commaJoin(slice []string) string {
|
func commaJoin(slice []string) string {
|
||||||
@@ -149,13 +148,9 @@ func noServerFoundError(selection settings.ServerSelection) (err error) {
|
|||||||
messageParts = append(messageParts, "tor only")
|
messageParts = append(messageParts, "tor only")
|
||||||
}
|
}
|
||||||
|
|
||||||
targetIP := selection.OpenVPN.EndpointIP
|
if selection.TargetIP.IsValid() {
|
||||||
if selection.VPN == vpn.Wireguard {
|
|
||||||
targetIP = selection.Wireguard.EndpointIP
|
|
||||||
}
|
|
||||||
if targetIP.IsValid() {
|
|
||||||
messageParts = append(messageParts,
|
messageParts = append(messageParts,
|
||||||
"target ip address "+targetIP.String())
|
"target ip address "+selection.TargetIP.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
message := "for " + strings.Join(messageParts, "; ")
|
message := "for " + strings.Join(messageParts, "; ")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -101,7 +101,7 @@ type CmdStarter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type HealthChecker interface {
|
type HealthChecker interface {
|
||||||
SetConfig(tlsDialAddrs []string, icmpTargetIPs []netip.Addr, smallCheckType string)
|
SetConfig(tlsDialAddr string, icmpTarget netip.Addr)
|
||||||
Start(ctx context.Context) (runError <-chan error, err error)
|
Start(ctx context.Context) (runError <-chan error, err error)
|
||||||
Stop() error
|
Stop() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/check"
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/version"
|
"github.com/qdm12/gluetun/internal/version"
|
||||||
)
|
)
|
||||||
@@ -31,34 +30,22 @@ func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
icmpTargetIPs := l.healthSettings.ICMPTargetIPs
|
icmpTarget := l.healthSettings.ICMPTargetIP
|
||||||
if len(icmpTargetIPs) == 1 && icmpTargetIPs[0].IsUnspecified() {
|
if icmpTarget.IsUnspecified() {
|
||||||
icmpTargetIPs = []netip.Addr{data.serverIP}
|
icmpTarget = data.serverIP
|
||||||
}
|
}
|
||||||
l.healthChecker.SetConfig(l.healthSettings.TargetAddresses, icmpTargetIPs,
|
l.healthChecker.SetConfig(l.healthSettings.TargetAddress, icmpTarget)
|
||||||
l.healthSettings.SmallCheckType)
|
|
||||||
|
|
||||||
healthErrCh, err := l.healthChecker.Start(ctx)
|
healthErrCh, err := l.healthChecker.Start(ctx)
|
||||||
l.healthServer.SetError(err)
|
l.healthServer.SetError(err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if *l.healthSettings.RestartVPN {
|
// Note this restart call must be done in a separate goroutine
|
||||||
// Note this restart call must be done in a separate goroutine
|
// from the VPN loop goroutine.
|
||||||
// from the VPN loop goroutine.
|
l.restartVPN(loopCtx, err)
|
||||||
l.restartVPN(loopCtx, err)
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
l.logger.Warnf("(ignored) healthchecker start failed: %s", err)
|
|
||||||
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if *l.dnsLooper.GetSettings().ServerEnabled {
|
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
|
||||||
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
|
|
||||||
} else {
|
|
||||||
err := check.WaitForDNS(ctx, check.Settings{})
|
|
||||||
if err != nil {
|
|
||||||
l.logger.Error("waiting for DNS to be ready: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = l.publicip.RunOnce(ctx)
|
err = l.publicip.RunOnce(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -100,7 +87,7 @@ func (l *Loop) collectHealthErrors(ctx, loopCtx context.Context, healthErrCh <-c
|
|||||||
l.restartVPN(loopCtx, healthErr)
|
l.restartVPN(loopCtx, healthErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
l.logger.Warnf("(ignored) healthcheck failed: %s", healthErr)
|
l.logger.Warnf("healthcheck failed: %s", healthErr)
|
||||||
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
||||||
} else if previousHealthErr != nil {
|
} else if previousHealthErr != nil {
|
||||||
l.logger.Info("healthcheck passed successfully after previous failure(s)")
|
l.logger.Info("healthcheck passed successfully after previous failure(s)")
|
||||||
|
|||||||
Reference in New Issue
Block a user