feat(server): role based authentication system (#2434)
- Parse toml configuration file, see https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/control-server.md#authentication - Retro-compatible with existing AND documented routes, until after v3.41 release - Log a warning if an unprotected-by-default route is accessed unprotected - Authentication methods: none, apikey, basic - `genkey` command to generate API keys - move log middleware to internal/server/middlewares/log Co-authored-by: Joe Jose <45399349+joejose97@users.noreply.github.com>
This commit is contained in:
@@ -197,6 +197,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
# Control server
|
||||
HTTP_CONTROL_SERVER_LOG=on \
|
||||
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
||||
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
|
||||
# Server data updater
|
||||
UPDATER_PERIOD=0 \
|
||||
UPDATER_MIN_RATIO=0.8 \
|
||||
|
||||
@@ -161,12 +161,14 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
return cli.Update(ctx, args[2:], logger)
|
||||
case "format-servers":
|
||||
return cli.FormatServers(args[2:])
|
||||
case "genkey":
|
||||
return cli.GenKey(args[2:])
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||
}
|
||||
}
|
||||
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
|
||||
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
Version: buildInfo.Version,
|
||||
Commit: buildInfo.Commit,
|
||||
Created: buildInfo.Created,
|
||||
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
|
||||
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||
AnnounceExp: announcementExp,
|
||||
// Sponsor information
|
||||
PaypalUser: "qmcgaw",
|
||||
@@ -474,6 +476,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
||||
logger.New(log.SetComponent("http server")),
|
||||
allSettings.ControlServer.AuthFilePath,
|
||||
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
@@ -595,6 +598,7 @@ type clier interface {
|
||||
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
||||
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
||||
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
||||
GenKey(args []string) error
|
||||
}
|
||||
|
||||
type Tun interface {
|
||||
|
||||
1
go.mod
1
go.mod
@@ -8,6 +8,7 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/klauspost/compress v1.17.8
|
||||
github.com/klauspost/pgzip v1.2.6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2
|
||||
github.com/qdm12/dns v1.11.0
|
||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
||||
github.com/qdm12/gosettings v0.4.2
|
||||
|
||||
8
go.sum
8
go.sum
@@ -83,6 +83,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -113,10 +115,16 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
|
||||
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
||||
|
||||
66
internal/cli/genkey.go
Normal file
66
internal/cli/genkey.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *CLI) GenKey(args []string) (err error) {
|
||||
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
|
||||
err = flagSet.Parse(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing flags: %w", err)
|
||||
}
|
||||
|
||||
const keyLength = 128 / 8
|
||||
keyBytes := make([]byte, keyLength)
|
||||
|
||||
_, _ = rand.Read(keyBytes)
|
||||
|
||||
key := base58Encode(keyBytes)
|
||||
fmt.Println(key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func base58Encode(data []byte) string {
|
||||
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||
const radix = 58
|
||||
|
||||
zcount := 0
|
||||
for zcount < len(data) && data[zcount] == 0 {
|
||||
zcount++
|
||||
}
|
||||
|
||||
// integer simplification of ceil(log(256)/log(58))
|
||||
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
|
||||
size := zcount + ceilLog256Div58
|
||||
|
||||
output := make([]byte, size)
|
||||
|
||||
high := size - 1
|
||||
for _, b := range data {
|
||||
i := size - 1
|
||||
for carry := uint32(b); i > high || carry != 0; i-- {
|
||||
carry += 256 * uint32(output[i]) //nolint:gomnd
|
||||
output[i] = byte(carry % radix)
|
||||
carry /= radix
|
||||
}
|
||||
high = i
|
||||
}
|
||||
|
||||
// Determine the additional "zero-gap" in the output buffer
|
||||
additionalZeroGapEnd := zcount
|
||||
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
|
||||
additionalZeroGapEnd++
|
||||
}
|
||||
|
||||
val := output[additionalZeroGapEnd-zcount:]
|
||||
size = len(val)
|
||||
for i := range val {
|
||||
output[i] = alphabet[val[i]]
|
||||
}
|
||||
|
||||
return string(output[:size])
|
||||
}
|
||||
@@ -19,6 +19,11 @@ type ControlServer struct {
|
||||
// Log can be true or false to enable logging on requests.
|
||||
// It cannot be nil in the internal state.
|
||||
Log *bool
|
||||
// AuthFilePath is the path to the file containing the authentication
|
||||
// configuration for the middleware.
|
||||
// It cannot be empty in the internal state and defaults to
|
||||
// /gluetun/auth/config.toml.
|
||||
AuthFilePath string
|
||||
}
|
||||
|
||||
func (c ControlServer) validate() (err error) {
|
||||
@@ -44,8 +49,9 @@ func (c ControlServer) validate() (err error) {
|
||||
|
||||
func (c *ControlServer) copy() (copied ControlServer) {
|
||||
return ControlServer{
|
||||
Address: gosettings.CopyPointer(c.Address),
|
||||
Log: gosettings.CopyPointer(c.Log),
|
||||
Address: gosettings.CopyPointer(c.Address),
|
||||
Log: gosettings.CopyPointer(c.Log),
|
||||
AuthFilePath: c.AuthFilePath,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,11 +61,13 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
||||
func (c *ControlServer) overrideWith(other ControlServer) {
|
||||
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
||||
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
||||
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
|
||||
}
|
||||
|
||||
func (c *ControlServer) setDefaults() {
|
||||
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
||||
c.Log = gosettings.DefaultPointer(c.Log, true)
|
||||
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
|
||||
}
|
||||
|
||||
func (c ControlServer) String() string {
|
||||
@@ -70,6 +78,7 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
|
||||
node = gotree.New("Control server settings:")
|
||||
node.Appendf("Listening address: %s", *c.Address)
|
||||
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
||||
node.Appendf("Authentication file path: %s", c.AuthFilePath)
|
||||
return node
|
||||
}
|
||||
|
||||
@@ -78,6 +87,10 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
||||
|
||||
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -78,7 +78,8 @@ func Test_Settings_String(t *testing.T) {
|
||||
| └── Enabled: no
|
||||
├── Control server settings:
|
||||
| ├── Listening address: :8000
|
||||
| └── Logging: yes
|
||||
| ├── Logging: yes
|
||||
| └── Authentication file path: /gluetun/auth/config.toml
|
||||
├── OS Alpine settings:
|
||||
| ├── Process UID: 1000
|
||||
| └── Process GID: 1000
|
||||
|
||||
@@ -2,13 +2,17 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/log"
|
||||
)
|
||||
|
||||
func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||
authSettings auth.Settings,
|
||||
buildInfo models.BuildInformation,
|
||||
vpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter,
|
||||
@@ -17,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
publicIPLooper PublicIPLoop,
|
||||
storage Storage,
|
||||
ipv6Supported bool,
|
||||
) http.Handler {
|
||||
) (httpHandler http.Handler, err error) {
|
||||
handler := &handler{}
|
||||
|
||||
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
||||
@@ -29,16 +33,25 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
||||
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
||||
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
||||
|
||||
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
||||
handler.setLogEnabled = handlerWithLog.setEnabled
|
||||
authMiddleware, err := auth.New(authSettings, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating auth middleware: %w", err)
|
||||
}
|
||||
|
||||
return handlerWithLog
|
||||
middlewares := []func(http.Handler) http.Handler{
|
||||
authMiddleware,
|
||||
log.New(logger, logging),
|
||||
}
|
||||
httpHandler = handler
|
||||
for _, middleware := range middlewares {
|
||||
httpHandler = middleware(httpHandler)
|
||||
}
|
||||
return httpHandler, nil
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
v0 http.Handler
|
||||
v1 http.Handler
|
||||
setLogEnabled func(enabled bool)
|
||||
v0 http.Handler
|
||||
v1 http.Handler
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package server
|
||||
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
infoer
|
||||
warner
|
||||
Warnf(format string, args ...any)
|
||||
errorer
|
||||
}
|
||||
|
||||
|
||||
36
internal/server/middlewares/auth/apikey.go
Normal file
36
internal/server/middlewares/auth/apikey.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type apiKeyMethod struct {
|
||||
apiKeyDigest [32]byte
|
||||
}
|
||||
|
||||
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
|
||||
return &apiKeyMethod{
|
||||
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
|
||||
otherTokenMethod, ok := other.(*apiKeyMethod)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
|
||||
}
|
||||
|
||||
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
|
||||
xAPIKey := request.Header.Get("X-API-Key")
|
||||
if xAPIKey == "" {
|
||||
xAPIKey = request.URL.Query().Get("api_key")
|
||||
}
|
||||
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
|
||||
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
|
||||
}
|
||||
37
internal/server/middlewares/auth/basic.go
Normal file
37
internal/server/middlewares/auth/basic.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type basicAuthMethod struct {
|
||||
authDigest [32]byte
|
||||
}
|
||||
|
||||
func newBasicAuthMethod(username, password string) *basicAuthMethod {
|
||||
return &basicAuthMethod{
|
||||
authDigest: sha256.Sum256([]byte(username + password)),
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
|
||||
otherBasicMethod, ok := other.(*basicAuthMethod)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return a.authDigest == otherBasicMethod.authDigest
|
||||
}
|
||||
|
||||
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
|
||||
username, password, ok := request.BasicAuth()
|
||||
if !ok {
|
||||
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
return false
|
||||
}
|
||||
requestAuthDigest := sha256.Sum256([]byte(username + password))
|
||||
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
|
||||
}
|
||||
35
internal/server/middlewares/auth/configfile.go
Normal file
35
internal/server/middlewares/auth/configfile.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
// If the file does not exist, it returns empty settings and no error.
|
||||
func Read(filepath string) (settings Settings, err error) {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return Settings{}, nil
|
||||
}
|
||||
return settings, fmt.Errorf("opening file: %w", err)
|
||||
}
|
||||
decoder := toml.NewDecoder(file)
|
||||
decoder.DisallowUnknownFields()
|
||||
err = decoder.Decode(&settings)
|
||||
if err == nil {
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
strictErr := new(toml.StrictMissingError)
|
||||
ok := errors.As(err, &strictErr)
|
||||
if !ok {
|
||||
return settings, fmt.Errorf("toml decoding file: %w", err)
|
||||
}
|
||||
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
|
||||
strictErr, strictErr.String())
|
||||
}
|
||||
80
internal/server/middlewares/auth/configfile_test.go
Normal file
80
internal/server/middlewares/auth/configfile_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
func Test_Read(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
fileContent string
|
||||
settings Settings
|
||||
errMessage string
|
||||
}{
|
||||
"empty_file": {},
|
||||
"malformed_toml": {
|
||||
fileContent: "this is not a toml file",
|
||||
errMessage: `toml decoding file: toml: expected character =`,
|
||||
},
|
||||
"unknown_field": {
|
||||
fileContent: `unknown = "what is this"`,
|
||||
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
|
||||
1| unknown = "what is this"
|
||||
| ~~~~~~~ missing field`,
|
||||
},
|
||||
"filled_settings": {
|
||||
fileContent: `[[roles]]
|
||||
name = "public"
|
||||
auth = "none"
|
||||
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
|
||||
|
||||
[[roles]]
|
||||
name = "client"
|
||||
auth = "apikey"
|
||||
apikey = "xyz"
|
||||
routes = ["GET /v1/vpn/status"]
|
||||
`,
|
||||
settings: Settings{
|
||||
Roles: []Role{{
|
||||
Name: "public",
|
||||
Auth: AuthNone,
|
||||
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
|
||||
}, {
|
||||
Name: "client",
|
||||
Auth: AuthAPIKey,
|
||||
APIKey: "xyz",
|
||||
Routes: []string{"GET /v1/vpn/status"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
filepath := tempDir + "/config.toml"
|
||||
const permissions fs.FileMode = 0600
|
||||
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings, err := Read(filepath)
|
||||
|
||||
assert.Equal(t, testCase.settings, settings)
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
22
internal/server/middlewares/auth/format.go
Normal file
22
internal/server/middlewares/auth/format.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package auth
|
||||
|
||||
func andStrings(strings []string) (result string) {
|
||||
return joinStrings(strings, "and")
|
||||
}
|
||||
|
||||
func joinStrings(strings []string, lastJoin string) (result string) {
|
||||
if len(strings) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result = strings[0]
|
||||
for i := 1; i < len(strings); i++ {
|
||||
if i < len(strings)-1 {
|
||||
result += ", " + strings[i]
|
||||
} else {
|
||||
result += " " + lastJoin + " " + strings[i]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
6
internal/server/middlewares/auth/interfaces.go
Normal file
6
internal/server/middlewares/auth/interfaces.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package auth
|
||||
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
}
|
||||
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
type authorizationChecker interface {
|
||||
equal(other authorizationChecker) bool
|
||||
isAuthorized(headers http.Header, request *http.Request) bool
|
||||
}
|
||||
47
internal/server/middlewares/auth/lookup.go
Normal file
47
internal/server/middlewares/auth/lookup.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type internalRole struct {
|
||||
name string
|
||||
checker authorizationChecker
|
||||
}
|
||||
|
||||
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
|
||||
routeToRoles = make(map[string][]internalRole)
|
||||
for _, role := range settings.Roles {
|
||||
var checker authorizationChecker
|
||||
switch role.Auth {
|
||||
case AuthNone:
|
||||
checker = newNoneMethod()
|
||||
case AuthAPIKey:
|
||||
checker = newAPIKeyMethod(role.APIKey)
|
||||
case AuthBasic:
|
||||
checker = newBasicAuthMethod(role.Username, role.Password)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
|
||||
}
|
||||
|
||||
iRole := internalRole{
|
||||
name: role.Name,
|
||||
checker: checker,
|
||||
}
|
||||
for _, route := range role.Routes {
|
||||
checkerExists := false
|
||||
for _, role := range routeToRoles[route] {
|
||||
if role.checker.equal(iRole.checker) {
|
||||
checkerExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if checkerExists {
|
||||
// even if the role name is different, if the checker is the same, skip it.
|
||||
continue
|
||||
}
|
||||
routeToRoles[route] = append(routeToRoles[route], iRole)
|
||||
}
|
||||
}
|
||||
return routeToRoles, nil
|
||||
}
|
||||
60
internal/server/middlewares/auth/lookup_test.go
Normal file
60
internal/server/middlewares/auth/lookup_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Read reads the toml file specified by the filepath given.
|
||||
func Test_settingsToLookupMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
routeToRoles map[string][]internalRole
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"empty_settings": {
|
||||
routeToRoles: map[string][]internalRole{},
|
||||
},
|
||||
"auth_method_not_supported": {
|
||||
settings: Settings{
|
||||
Roles: []Role{{Name: "a", Auth: "bad"}},
|
||||
},
|
||||
errWrapped: ErrMethodNotSupported,
|
||||
errMessage: "authentication method not supported: bad",
|
||||
},
|
||||
"success": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
|
||||
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
|
||||
},
|
||||
},
|
||||
routeToRoles: map[string][]internalRole{
|
||||
"GET /path": {
|
||||
{name: "a", checker: newNoneMethod()}, // deduplicated method
|
||||
},
|
||||
"PUT /path": {
|
||||
{name: "b", checker: newNoneMethod()},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
routeToRoles, err := settingsToLookupMap(testCase.settings)
|
||||
|
||||
assert.Equal(t, testCase.routeToRoles, routeToRoles)
|
||||
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||
if testCase.errWrapped != nil {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
111
internal/server/middlewares/auth/middleware.go
Normal file
111
internal/server/middlewares/auth/middleware.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func New(settings Settings, debugLogger DebugLogger) (
|
||||
middleware func(http.Handler) http.Handler,
|
||||
err error) {
|
||||
routeToRoles, err := settingsToLookupMap(settings)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
|
||||
}
|
||||
|
||||
//nolint:goconst
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return &authHandler{
|
||||
childHandler: handler,
|
||||
routeToRoles: routeToRoles,
|
||||
unprotectedRoutes: map[string]struct{}{
|
||||
http.MethodGet + " /openvpn/actions/restart": {},
|
||||
http.MethodGet + " /unbound/actions/restart": {},
|
||||
http.MethodGet + " /updater/restart": {},
|
||||
http.MethodGet + " /v1/version": {},
|
||||
http.MethodGet + " /v1/vpn/status": {},
|
||||
http.MethodPut + " /v1/vpn/status": {},
|
||||
// GET /v1/vpn/settings is protected by default
|
||||
// PUT /v1/vpn/settings is protected by default
|
||||
http.MethodGet + " /v1/openvpn/status": {},
|
||||
http.MethodPut + " /v1/openvpn/status": {},
|
||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||
// GET /v1/openvpn/settings is protected by default
|
||||
http.MethodGet + " /v1/dns/status": {},
|
||||
http.MethodPut + " /v1/dns/status": {},
|
||||
http.MethodGet + " /v1/updater/status": {},
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
},
|
||||
logger: debugLogger,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
childHandler http.Handler
|
||||
routeToRoles map[string][]internalRole
|
||||
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
|
||||
logger DebugLogger
|
||||
}
|
||||
|
||||
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
route := request.Method + " " + request.URL.Path
|
||||
roles := h.routeToRoles[route]
|
||||
if len(roles) == 0 {
|
||||
h.logger.Debugf("no authentication role defined for route %s", route)
|
||||
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
responseHeader := make(http.Header, 0)
|
||||
for _, role := range roles {
|
||||
if !role.checker.isAuthorized(responseHeader, request) {
|
||||
continue
|
||||
}
|
||||
|
||||
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
|
||||
|
||||
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
|
||||
h.childHandler.ServeHTTP(writer, request)
|
||||
return
|
||||
}
|
||||
|
||||
// Flush out response headers if all roles failed to authenticate
|
||||
for headerKey, headerValues := range responseHeader {
|
||||
for _, headerValue := range headerValues {
|
||||
writer.Header().Add(headerKey, headerValue)
|
||||
}
|
||||
}
|
||||
|
||||
allRoleNames := make([]string, len(roles))
|
||||
for i, role := range roles {
|
||||
allRoleNames[i] = role.name
|
||||
}
|
||||
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
|
||||
route, andStrings(allRoleNames))
|
||||
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
|
||||
// TODO v3.41.0 remove
|
||||
if role.name != "public" {
|
||||
// custom role name, allow none authentication to be specified
|
||||
return
|
||||
}
|
||||
_, isNoneChecker := role.checker.(*noneMethod)
|
||||
if !isNoneChecker {
|
||||
// not the none authentication method
|
||||
return
|
||||
}
|
||||
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
|
||||
if !isUnprotectedByDefault {
|
||||
// route is not unprotected by default, so this is a user decision
|
||||
return
|
||||
}
|
||||
h.logger.Warnf("route %s is unprotected by default, "+
|
||||
"please set up authentication following the documentation at "+
|
||||
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||
"since this will become no longer publicly accessible after release v3.40.",
|
||||
route)
|
||||
}
|
||||
124
internal/server/middlewares/auth/middleware_test.go
Normal file
124
internal/server/middlewares/auth/middleware_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_authHandler_ServeHTTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
settings Settings
|
||||
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
|
||||
requestMethod string
|
||||
requestPath string
|
||||
statusCode int
|
||||
responseBody string
|
||||
}{
|
||||
"route_has_no_role": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/b",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
responseBody: "Unauthorized\n",
|
||||
},
|
||||
"authorized_unprotected_by_default": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Warnf("route %s is unprotected by default, "+
|
||||
"please set up authentication following the documentation at "+
|
||||
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||
"since this will become no longer publicly accessible after release v3.40.",
|
||||
"GET /v1/vpn/status")
|
||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||
"GET /v1/vpn/status", "public")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/v1/vpn/status",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
"authorized_none": {
|
||||
settings: Settings{
|
||||
Roles: []Role{
|
||||
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||
},
|
||||
},
|
||||
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
logger := NewMockDebugLogger(ctrl)
|
||||
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||
"GET /a", "role1")
|
||||
return logger
|
||||
},
|
||||
requestMethod: http.MethodGet,
|
||||
requestPath: "/a",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var debugLogger DebugLogger
|
||||
if testCase.makeLogger != nil {
|
||||
debugLogger = testCase.makeLogger(ctrl)
|
||||
}
|
||||
middleware, err := New(testCase.settings, debugLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := middleware(childHandler)
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
client := server.Client()
|
||||
|
||||
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
|
||||
require.NoError(t, err)
|
||||
request, err := http.NewRequestWithContext(context.Background(),
|
||||
testCase.requestMethod, requestURL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := client.Do(request)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err = response.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
assert.Equal(t, testCase.statusCode, response.StatusCode)
|
||||
body, err := io.ReadAll(response.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testCase.responseBody, string(body))
|
||||
})
|
||||
}
|
||||
}
|
||||
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package auth
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger
|
||||
68
internal/server/middlewares/auth/mocks_test.go
Normal file
68
internal/server/middlewares/auth/mocks_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
|
||||
|
||||
// Package auth is a generated GoMock package.
|
||||
package auth
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDebugLogger is a mock of DebugLogger interface.
|
||||
type MockDebugLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDebugLoggerMockRecorder
|
||||
}
|
||||
|
||||
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
|
||||
type MockDebugLoggerMockRecorder struct {
|
||||
mock *MockDebugLogger
|
||||
}
|
||||
|
||||
// NewMockDebugLogger creates a new mock instance.
|
||||
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
|
||||
mock := &MockDebugLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockDebugLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Debugf mocks base method.
|
||||
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Debugf", varargs...)
|
||||
}
|
||||
|
||||
// Debugf indicates an expected call of Debugf.
|
||||
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
|
||||
}
|
||||
|
||||
// Warnf mocks base method.
|
||||
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "Warnf", varargs...)
|
||||
}
|
||||
|
||||
// Warnf indicates an expected call of Warnf.
|
||||
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
|
||||
}
|
||||
20
internal/server/middlewares/auth/none.go
Normal file
20
internal/server/middlewares/auth/none.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package auth
|
||||
|
||||
import "net/http"
|
||||
|
||||
type noneMethod struct{}
|
||||
|
||||
func newNoneMethod() *noneMethod {
|
||||
return &noneMethod{}
|
||||
}
|
||||
|
||||
// equal returns true if another auth checker is equal.
|
||||
// This is used to deduplicate checkers for a particular route.
|
||||
func (n *noneMethod) equal(other authorizationChecker) bool {
|
||||
_, ok := other.(*noneMethod)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
131
internal/server/middlewares/auth/settings.go
Normal file
131
internal/server/middlewares/auth/settings.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gosettings"
|
||||
"github.com/qdm12/gosettings/validate"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
// Roles is a list of roles with their associated authentication
|
||||
// and routes.
|
||||
Roles []Role
|
||||
}
|
||||
|
||||
func (s *Settings) SetDefaults() {
|
||||
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
|
||||
Name: "public",
|
||||
Auth: "none",
|
||||
Routes: []string{
|
||||
http.MethodGet + " /openvpn/actions/restart",
|
||||
http.MethodGet + " /unbound/actions/restart",
|
||||
http.MethodGet + " /updater/restart",
|
||||
http.MethodGet + " /v1/version",
|
||||
http.MethodGet + " /v1/vpn/status",
|
||||
http.MethodPut + " /v1/vpn/status",
|
||||
http.MethodGet + " /v1/openvpn/status",
|
||||
http.MethodPut + " /v1/openvpn/status",
|
||||
http.MethodGet + " /v1/openvpn/portforwarded",
|
||||
http.MethodGet + " /v1/dns/status",
|
||||
http.MethodPut + " /v1/dns/status",
|
||||
http.MethodGet + " /v1/updater/status",
|
||||
http.MethodPut + " /v1/updater/status",
|
||||
http.MethodGet + " /v1/publicip/ip",
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
func (s Settings) Validate() (err error) {
|
||||
for i, role := range s.Roles {
|
||||
err = role.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("role %s (%d of %d): %w",
|
||||
role.Name, i+1, len(s.Roles), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
AuthNone = "none"
|
||||
AuthAPIKey = "apikey"
|
||||
AuthBasic = "basic"
|
||||
)
|
||||
|
||||
// Role contains the role name, authentication method name and
|
||||
// routes that the role can access.
|
||||
type Role struct {
|
||||
// Name is the role name and is only used for documentation
|
||||
// and in the authentication middleware debug logs.
|
||||
Name string
|
||||
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
|
||||
Auth string
|
||||
// APIKey is the API key to use when using the 'apikey' authentication.
|
||||
APIKey string
|
||||
// Username for HTTP Basic authentication method.
|
||||
Username string
|
||||
// Password for HTTP Basic authentication method.
|
||||
Password string
|
||||
// Routes is a list of routes that the role can access in the format
|
||||
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
|
||||
Routes []string
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMethodNotSupported = errors.New("authentication method not supported")
|
||||
ErrAPIKeyEmpty = errors.New("api key is empty")
|
||||
ErrBasicUsernameEmpty = errors.New("username is empty")
|
||||
ErrBasicPasswordEmpty = errors.New("password is empty")
|
||||
ErrRouteNotSupported = errors.New("route not supported by the control server")
|
||||
)
|
||||
|
||||
func (r Role) validate() (err error) {
|
||||
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Auth == AuthAPIKey && r.APIKey == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
|
||||
case r.Auth == AuthBasic && r.Username == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
|
||||
case r.Auth == AuthBasic && r.Password == "":
|
||||
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
|
||||
}
|
||||
|
||||
for i, route := range r.Routes {
|
||||
_, ok := validRoutes[route]
|
||||
if !ok {
|
||||
return fmt.Errorf("route %d of %d: %w: %s",
|
||||
i+1, len(r.Routes), ErrRouteNotSupported, route)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WARNING: do not mutate programmatically.
|
||||
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
||||
http.MethodGet + " /openvpn/actions/restart": {},
|
||||
http.MethodGet + " /unbound/actions/restart": {},
|
||||
http.MethodGet + " /updater/restart": {},
|
||||
http.MethodGet + " /v1/version": {},
|
||||
http.MethodGet + " /v1/vpn/status": {},
|
||||
http.MethodPut + " /v1/vpn/status": {},
|
||||
http.MethodGet + " /v1/vpn/settings": {},
|
||||
http.MethodPut + " /v1/vpn/settings": {},
|
||||
http.MethodGet + " /v1/openvpn/status": {},
|
||||
http.MethodPut + " /v1/openvpn/status": {},
|
||||
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||
http.MethodGet + " /v1/openvpn/settings": {},
|
||||
http.MethodGet + " /v1/dns/status": {},
|
||||
http.MethodPut + " /v1/dns/status": {},
|
||||
http.MethodGet + " /v1/updater/status": {},
|
||||
http.MethodPut + " /v1/updater/status": {},
|
||||
http.MethodGet + " /v1/publicip/ip": {},
|
||||
}
|
||||
5
internal/server/middlewares/log/interfaces.go
Normal file
5
internal/server/middlewares/log/interfaces.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package log
|
||||
|
||||
type Logger interface {
|
||||
Info(message string)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package log
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
|
||||
return &logMiddleware{
|
||||
childHandler: childHandler,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
enabled: enabled,
|
||||
func New(logger Logger, enabled bool) (
|
||||
middleware func(http.Handler) http.Handler) {
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return &logMiddleware{
|
||||
childHandler: handler,
|
||||
logger: logger,
|
||||
timeNow: time.Now,
|
||||
enabled: enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type logMiddleware struct {
|
||||
childHandler http.Handler
|
||||
logger infoer
|
||||
logger Logger
|
||||
timeNow func() time.Time
|
||||
enabled bool
|
||||
enabledMu sync.RWMutex
|
||||
@@ -39,7 +42,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.RemoteAddr + " in " + duration.String())
|
||||
}
|
||||
|
||||
func (m *logMiddleware) setEnabled(enabled bool) {
|
||||
func (m *logMiddleware) SetEnabled(enabled bool) {
|
||||
m.enabledMu.Lock()
|
||||
defer m.enabledMu.Unlock()
|
||||
m.enabled = enabled
|
||||
@@ -6,17 +6,31 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/httpserver"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||
)
|
||||
|
||||
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
||||
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
||||
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
||||
ipv6Supported bool) (
|
||||
server *httpserver.Server, err error) {
|
||||
handler := newHandler(ctx, logger, logEnabled, buildInfo,
|
||||
authSettings, err := auth.Read(authConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading auth settings: %w", err)
|
||||
}
|
||||
authSettings.SetDefaults()
|
||||
err = authSettings.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating auth settings: %w", err)
|
||||
}
|
||||
|
||||
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
|
||||
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
||||
storage, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating handler: %w", err)
|
||||
}
|
||||
|
||||
httpServerSettings := httpserver.Settings{
|
||||
Address: address,
|
||||
|
||||
Reference in New Issue
Block a user