feat(server): role based authentication system (#2434)
- Parse toml configuration file, see https://github.com/qdm12/gluetun-wiki/blob/main/setup/advanced/control-server.md#authentication - Retro-compatible with existing AND documented routes, until after v3.41 release - Log a warning if an unprotected-by-default route is accessed unprotected - Authentication methods: none, apikey, basic - `genkey` command to generate API keys - move log middleware to internal/server/middlewares/log Co-authored-by: Joe Jose <45399349+joejose97@users.noreply.github.com>
This commit is contained in:
@@ -197,6 +197,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
# Control server
|
# Control server
|
||||||
HTTP_CONTROL_SERVER_LOG=on \
|
HTTP_CONTROL_SERVER_LOG=on \
|
||||||
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
HTTP_CONTROL_SERVER_ADDRESS=":8000" \
|
||||||
|
HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \
|
||||||
# Server data updater
|
# Server data updater
|
||||||
UPDATER_PERIOD=0 \
|
UPDATER_PERIOD=0 \
|
||||||
UPDATER_MIN_RATIO=0.8 \
|
UPDATER_MIN_RATIO=0.8 \
|
||||||
|
|||||||
@@ -161,12 +161,14 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
return cli.Update(ctx, args[2:], logger)
|
return cli.Update(ctx, args[2:], logger)
|
||||||
case "format-servers":
|
case "format-servers":
|
||||||
return cli.FormatServers(args[2:])
|
return cli.FormatServers(args[2:])
|
||||||
|
case "genkey":
|
||||||
|
return cli.GenKey(args[2:])
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
return fmt.Errorf("%w: %s", errCommandUnknown, args[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z")
|
announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
Commit: buildInfo.Commit,
|
Commit: buildInfo.Commit,
|
||||||
Created: buildInfo.Created,
|
Created: buildInfo.Created,
|
||||||
Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki",
|
Announcement: "All control server routes will become private by default after the v3.41.0 release",
|
||||||
AnnounceExp: announcementExp,
|
AnnounceExp: announcementExp,
|
||||||
// Sponsor information
|
// Sponsor information
|
||||||
PaypalUser: "qmcgaw",
|
PaypalUser: "qmcgaw",
|
||||||
@@ -474,6 +476,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
"http server", goroutine.OptionTimeout(defaultShutdownTimeout))
|
||||||
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
|
||||||
logger.New(log.SetComponent("http server")),
|
logger.New(log.SetComponent("http server")),
|
||||||
|
allSettings.ControlServer.AuthFilePath,
|
||||||
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
|
||||||
storage, ipv6Supported)
|
storage, ipv6Supported)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -595,6 +598,7 @@ type clier interface {
|
|||||||
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error
|
||||||
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error
|
||||||
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error
|
||||||
|
GenKey(args []string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tun interface {
|
type Tun interface {
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -8,6 +8,7 @@ require (
|
|||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
github.com/klauspost/compress v1.17.8
|
github.com/klauspost/compress v1.17.8
|
||||||
github.com/klauspost/pgzip v1.2.6
|
github.com/klauspost/pgzip v1.2.6
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/qdm12/dns v1.11.0
|
github.com/qdm12/dns v1.11.0
|
||||||
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6
|
||||||
github.com/qdm12/gosettings v0.4.2
|
github.com/qdm12/gosettings v0.4.2
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -83,6 +83,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
|
|||||||
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
|
||||||
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
|
||||||
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
@@ -113,10 +115,16 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
|
|||||||
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
|
||||||
|
|||||||
66
internal/cli/genkey.go
Normal file
66
internal/cli/genkey.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *CLI) GenKey(args []string) (err error) {
|
||||||
|
flagSet := flag.NewFlagSet("genkey", flag.ExitOnError)
|
||||||
|
err = flagSet.Parse(args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const keyLength = 128 / 8
|
||||||
|
keyBytes := make([]byte, keyLength)
|
||||||
|
|
||||||
|
_, _ = rand.Read(keyBytes)
|
||||||
|
|
||||||
|
key := base58Encode(keyBytes)
|
||||||
|
fmt.Println(key)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func base58Encode(data []byte) string {
|
||||||
|
const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||||
|
const radix = 58
|
||||||
|
|
||||||
|
zcount := 0
|
||||||
|
for zcount < len(data) && data[zcount] == 0 {
|
||||||
|
zcount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// integer simplification of ceil(log(256)/log(58))
|
||||||
|
ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd
|
||||||
|
size := zcount + ceilLog256Div58
|
||||||
|
|
||||||
|
output := make([]byte, size)
|
||||||
|
|
||||||
|
high := size - 1
|
||||||
|
for _, b := range data {
|
||||||
|
i := size - 1
|
||||||
|
for carry := uint32(b); i > high || carry != 0; i-- {
|
||||||
|
carry += 256 * uint32(output[i]) //nolint:gomnd
|
||||||
|
output[i] = byte(carry % radix)
|
||||||
|
carry /= radix
|
||||||
|
}
|
||||||
|
high = i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the additional "zero-gap" in the output buffer
|
||||||
|
additionalZeroGapEnd := zcount
|
||||||
|
for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 {
|
||||||
|
additionalZeroGapEnd++
|
||||||
|
}
|
||||||
|
|
||||||
|
val := output[additionalZeroGapEnd-zcount:]
|
||||||
|
size = len(val)
|
||||||
|
for i := range val {
|
||||||
|
output[i] = alphabet[val[i]]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(output[:size])
|
||||||
|
}
|
||||||
@@ -19,6 +19,11 @@ type ControlServer struct {
|
|||||||
// Log can be true or false to enable logging on requests.
|
// Log can be true or false to enable logging on requests.
|
||||||
// It cannot be nil in the internal state.
|
// It cannot be nil in the internal state.
|
||||||
Log *bool
|
Log *bool
|
||||||
|
// AuthFilePath is the path to the file containing the authentication
|
||||||
|
// configuration for the middleware.
|
||||||
|
// It cannot be empty in the internal state and defaults to
|
||||||
|
// /gluetun/auth/config.toml.
|
||||||
|
AuthFilePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ControlServer) validate() (err error) {
|
func (c ControlServer) validate() (err error) {
|
||||||
@@ -46,6 +51,7 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
|||||||
return ControlServer{
|
return ControlServer{
|
||||||
Address: gosettings.CopyPointer(c.Address),
|
Address: gosettings.CopyPointer(c.Address),
|
||||||
Log: gosettings.CopyPointer(c.Log),
|
Log: gosettings.CopyPointer(c.Log),
|
||||||
|
AuthFilePath: c.AuthFilePath,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,11 +61,13 @@ func (c *ControlServer) copy() (copied ControlServer) {
|
|||||||
func (c *ControlServer) overrideWith(other ControlServer) {
|
func (c *ControlServer) overrideWith(other ControlServer) {
|
||||||
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
c.Address = gosettings.OverrideWithPointer(c.Address, other.Address)
|
||||||
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
c.Log = gosettings.OverrideWithPointer(c.Log, other.Log)
|
||||||
|
c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ControlServer) setDefaults() {
|
func (c *ControlServer) setDefaults() {
|
||||||
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
c.Address = gosettings.DefaultPointer(c.Address, ":8000")
|
||||||
c.Log = gosettings.DefaultPointer(c.Log, true)
|
c.Log = gosettings.DefaultPointer(c.Log, true)
|
||||||
|
c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ControlServer) String() string {
|
func (c ControlServer) String() string {
|
||||||
@@ -70,6 +78,7 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) {
|
|||||||
node = gotree.New("Control server settings:")
|
node = gotree.New("Control server settings:")
|
||||||
node.Appendf("Listening address: %s", *c.Address)
|
node.Appendf("Listening address: %s", *c.Address)
|
||||||
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log))
|
||||||
|
node.Appendf("Authentication file path: %s", c.AuthFilePath)
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,6 +87,10 @@ func (c *ControlServer) read(r *reader.Reader) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS")
|
||||||
|
|
||||||
|
c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,8 @@ func Test_Settings_String(t *testing.T) {
|
|||||||
| └── Enabled: no
|
| └── Enabled: no
|
||||||
├── Control server settings:
|
├── Control server settings:
|
||||||
| ├── Listening address: :8000
|
| ├── Listening address: :8000
|
||||||
| └── Logging: yes
|
| ├── Logging: yes
|
||||||
|
| └── Authentication file path: /gluetun/auth/config.toml
|
||||||
├── OS Alpine settings:
|
├── OS Alpine settings:
|
||||||
| ├── Process UID: 1000
|
| ├── Process UID: 1000
|
||||||
| └── Process GID: 1000
|
| └── Process GID: 1000
|
||||||
|
|||||||
@@ -2,13 +2,17 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||||
|
"github.com/qdm12/gluetun/internal/server/middlewares/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
func newHandler(ctx context.Context, logger Logger, logging bool,
|
||||||
|
authSettings auth.Settings,
|
||||||
buildInfo models.BuildInformation,
|
buildInfo models.BuildInformation,
|
||||||
vpnLooper VPNLooper,
|
vpnLooper VPNLooper,
|
||||||
pfGetter PortForwardedGetter,
|
pfGetter PortForwardedGetter,
|
||||||
@@ -17,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
|||||||
publicIPLooper PublicIPLoop,
|
publicIPLooper PublicIPLoop,
|
||||||
storage Storage,
|
storage Storage,
|
||||||
ipv6Supported bool,
|
ipv6Supported bool,
|
||||||
) http.Handler {
|
) (httpHandler http.Handler, err error) {
|
||||||
handler := &handler{}
|
handler := &handler{}
|
||||||
|
|
||||||
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
|
||||||
@@ -29,16 +33,25 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
|
|||||||
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper)
|
||||||
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip)
|
||||||
|
|
||||||
handlerWithLog := withLogMiddleware(handler, logger, logging)
|
authMiddleware, err := auth.New(authSettings, logger)
|
||||||
handler.setLogEnabled = handlerWithLog.setEnabled
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating auth middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return handlerWithLog
|
middlewares := []func(http.Handler) http.Handler{
|
||||||
|
authMiddleware,
|
||||||
|
log.New(logger, logging),
|
||||||
|
}
|
||||||
|
httpHandler = handler
|
||||||
|
for _, middleware := range middlewares {
|
||||||
|
httpHandler = middleware(httpHandler)
|
||||||
|
}
|
||||||
|
return httpHandler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type handler struct {
|
type handler struct {
|
||||||
v0 http.Handler
|
v0 http.Handler
|
||||||
v1 http.Handler
|
v1 http.Handler
|
||||||
setLogEnabled func(enabled bool)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
|
Debugf(format string, args ...any)
|
||||||
infoer
|
infoer
|
||||||
warner
|
warner
|
||||||
|
Warnf(format string, args ...any)
|
||||||
errorer
|
errorer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
36
internal/server/middlewares/auth/apikey.go
Normal file
36
internal/server/middlewares/auth/apikey.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyMethod struct {
|
||||||
|
apiKeyDigest [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAPIKeyMethod(apiKey string) *apiKeyMethod {
|
||||||
|
return &apiKeyMethod{
|
||||||
|
apiKeyDigest: sha256.Sum256([]byte(apiKey)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal returns true if another auth checker is equal.
|
||||||
|
// This is used to deduplicate checkers for a particular route.
|
||||||
|
func (a *apiKeyMethod) equal(other authorizationChecker) bool {
|
||||||
|
otherTokenMethod, ok := other.(*apiKeyMethod)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.apiKeyDigest == otherTokenMethod.apiKeyDigest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool {
|
||||||
|
xAPIKey := request.Header.Get("X-API-Key")
|
||||||
|
if xAPIKey == "" {
|
||||||
|
xAPIKey = request.URL.Query().Get("api_key")
|
||||||
|
}
|
||||||
|
xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey))
|
||||||
|
return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1
|
||||||
|
}
|
||||||
37
internal/server/middlewares/auth/basic.go
Normal file
37
internal/server/middlewares/auth/basic.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type basicAuthMethod struct {
|
||||||
|
authDigest [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBasicAuthMethod(username, password string) *basicAuthMethod {
|
||||||
|
return &basicAuthMethod{
|
||||||
|
authDigest: sha256.Sum256([]byte(username + password)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal returns true if another auth checker is equal.
|
||||||
|
// This is used to deduplicate checkers for a particular route.
|
||||||
|
func (a *basicAuthMethod) equal(other authorizationChecker) bool {
|
||||||
|
otherBasicMethod, ok := other.(*basicAuthMethod)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.authDigest == otherBasicMethod.authDigest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool {
|
||||||
|
username, password, ok := request.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestAuthDigest := sha256.Sum256([]byte(username + password))
|
||||||
|
return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1
|
||||||
|
}
|
||||||
35
internal/server/middlewares/auth/configfile.go
Normal file
35
internal/server/middlewares/auth/configfile.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pelletier/go-toml/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read reads the toml file specified by the filepath given.
|
||||||
|
// If the file does not exist, it returns empty settings and no error.
|
||||||
|
func Read(filepath string) (settings Settings, err error) {
|
||||||
|
file, err := os.Open(filepath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return Settings{}, nil
|
||||||
|
}
|
||||||
|
return settings, fmt.Errorf("opening file: %w", err)
|
||||||
|
}
|
||||||
|
decoder := toml.NewDecoder(file)
|
||||||
|
decoder.DisallowUnknownFields()
|
||||||
|
err = decoder.Decode(&settings)
|
||||||
|
if err == nil {
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
strictErr := new(toml.StrictMissingError)
|
||||||
|
ok := errors.As(err, &strictErr)
|
||||||
|
if !ok {
|
||||||
|
return settings, fmt.Errorf("toml decoding file: %w", err)
|
||||||
|
}
|
||||||
|
return settings, fmt.Errorf("toml decoding file: %w:\n%s",
|
||||||
|
strictErr, strictErr.String())
|
||||||
|
}
|
||||||
80
internal/server/middlewares/auth/configfile_test.go
Normal file
80
internal/server/middlewares/auth/configfile_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read reads the toml file specified by the filepath given.
|
||||||
|
func Test_Read(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
fileContent string
|
||||||
|
settings Settings
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty_file": {},
|
||||||
|
"malformed_toml": {
|
||||||
|
fileContent: "this is not a toml file",
|
||||||
|
errMessage: `toml decoding file: toml: expected character =`,
|
||||||
|
},
|
||||||
|
"unknown_field": {
|
||||||
|
fileContent: `unknown = "what is this"`,
|
||||||
|
errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct:
|
||||||
|
1| unknown = "what is this"
|
||||||
|
| ~~~~~~~ missing field`,
|
||||||
|
},
|
||||||
|
"filled_settings": {
|
||||||
|
fileContent: `[[roles]]
|
||||||
|
name = "public"
|
||||||
|
auth = "none"
|
||||||
|
routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"]
|
||||||
|
|
||||||
|
[[roles]]
|
||||||
|
name = "client"
|
||||||
|
auth = "apikey"
|
||||||
|
apikey = "xyz"
|
||||||
|
routes = ["GET /v1/vpn/status"]
|
||||||
|
`,
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{{
|
||||||
|
Name: "public",
|
||||||
|
Auth: AuthNone,
|
||||||
|
Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"},
|
||||||
|
}, {
|
||||||
|
Name: "client",
|
||||||
|
Auth: AuthAPIKey,
|
||||||
|
APIKey: "xyz",
|
||||||
|
Routes: []string{"GET /v1/vpn/status"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
filepath := tempDir + "/config.toml"
|
||||||
|
const permissions fs.FileMode = 0600
|
||||||
|
err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
settings, err := Read(filepath)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.settings, settings)
|
||||||
|
if testCase.errMessage != "" {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
22
internal/server/middlewares/auth/format.go
Normal file
22
internal/server/middlewares/auth/format.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
func andStrings(strings []string) (result string) {
|
||||||
|
return joinStrings(strings, "and")
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinStrings(strings []string, lastJoin string) (result string) {
|
||||||
|
if len(strings) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings[0]
|
||||||
|
for i := 1; i < len(strings); i++ {
|
||||||
|
if i < len(strings)-1 {
|
||||||
|
result += ", " + strings[i]
|
||||||
|
} else {
|
||||||
|
result += " " + lastJoin + " " + strings[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
6
internal/server/middlewares/auth/interfaces.go
Normal file
6
internal/server/middlewares/auth/interfaces.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
type DebugLogger interface {
|
||||||
|
Debugf(format string, args ...any)
|
||||||
|
Warnf(format string, args ...any)
|
||||||
|
}
|
||||||
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
8
internal/server/middlewares/auth/interfaces_local.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type authorizationChecker interface {
|
||||||
|
equal(other authorizationChecker) bool
|
||||||
|
isAuthorized(headers http.Header, request *http.Request) bool
|
||||||
|
}
|
||||||
47
internal/server/middlewares/auth/lookup.go
Normal file
47
internal/server/middlewares/auth/lookup.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type internalRole struct {
|
||||||
|
name string
|
||||||
|
checker authorizationChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) {
|
||||||
|
routeToRoles = make(map[string][]internalRole)
|
||||||
|
for _, role := range settings.Roles {
|
||||||
|
var checker authorizationChecker
|
||||||
|
switch role.Auth {
|
||||||
|
case AuthNone:
|
||||||
|
checker = newNoneMethod()
|
||||||
|
case AuthAPIKey:
|
||||||
|
checker = newAPIKeyMethod(role.APIKey)
|
||||||
|
case AuthBasic:
|
||||||
|
checker = newBasicAuthMethod(role.Username, role.Password)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
iRole := internalRole{
|
||||||
|
name: role.Name,
|
||||||
|
checker: checker,
|
||||||
|
}
|
||||||
|
for _, route := range role.Routes {
|
||||||
|
checkerExists := false
|
||||||
|
for _, role := range routeToRoles[route] {
|
||||||
|
if role.checker.equal(iRole.checker) {
|
||||||
|
checkerExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if checkerExists {
|
||||||
|
// even if the role name is different, if the checker is the same, skip it.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
routeToRoles[route] = append(routeToRoles[route], iRole)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return routeToRoles, nil
|
||||||
|
}
|
||||||
60
internal/server/middlewares/auth/lookup_test.go
Normal file
60
internal/server/middlewares/auth/lookup_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read reads the toml file specified by the filepath given.
|
||||||
|
func Test_settingsToLookupMap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
settings Settings
|
||||||
|
routeToRoles map[string][]internalRole
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty_settings": {
|
||||||
|
routeToRoles: map[string][]internalRole{},
|
||||||
|
},
|
||||||
|
"auth_method_not_supported": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{{Name: "a", Auth: "bad"}},
|
||||||
|
},
|
||||||
|
errWrapped: ErrMethodNotSupported,
|
||||||
|
errMessage: "authentication method not supported: bad",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}},
|
||||||
|
{Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routeToRoles: map[string][]internalRole{
|
||||||
|
"GET /path": {
|
||||||
|
{name: "a", checker: newNoneMethod()}, // deduplicated method
|
||||||
|
},
|
||||||
|
"PUT /path": {
|
||||||
|
{name: "b", checker: newNoneMethod()},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
routeToRoles, err := settingsToLookupMap(testCase.settings)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.routeToRoles, routeToRoles)
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
111
internal/server/middlewares/auth/middleware.go
Normal file
111
internal/server/middlewares/auth/middleware.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(settings Settings, debugLogger DebugLogger) (
|
||||||
|
middleware func(http.Handler) http.Handler,
|
||||||
|
err error) {
|
||||||
|
routeToRoles, err := settingsToLookupMap(settings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("converting settings to lookup maps: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:goconst
|
||||||
|
return func(handler http.Handler) http.Handler {
|
||||||
|
return &authHandler{
|
||||||
|
childHandler: handler,
|
||||||
|
routeToRoles: routeToRoles,
|
||||||
|
unprotectedRoutes: map[string]struct{}{
|
||||||
|
http.MethodGet + " /openvpn/actions/restart": {},
|
||||||
|
http.MethodGet + " /unbound/actions/restart": {},
|
||||||
|
http.MethodGet + " /updater/restart": {},
|
||||||
|
http.MethodGet + " /v1/version": {},
|
||||||
|
http.MethodGet + " /v1/vpn/status": {},
|
||||||
|
http.MethodPut + " /v1/vpn/status": {},
|
||||||
|
// GET /v1/vpn/settings is protected by default
|
||||||
|
// PUT /v1/vpn/settings is protected by default
|
||||||
|
http.MethodGet + " /v1/openvpn/status": {},
|
||||||
|
http.MethodPut + " /v1/openvpn/status": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||||
|
// GET /v1/openvpn/settings is protected by default
|
||||||
|
http.MethodGet + " /v1/dns/status": {},
|
||||||
|
http.MethodPut + " /v1/dns/status": {},
|
||||||
|
http.MethodGet + " /v1/updater/status": {},
|
||||||
|
http.MethodPut + " /v1/updater/status": {},
|
||||||
|
http.MethodGet + " /v1/publicip/ip": {},
|
||||||
|
},
|
||||||
|
logger: debugLogger,
|
||||||
|
}
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authHandler struct {
|
||||||
|
childHandler http.Handler
|
||||||
|
routeToRoles map[string][]internalRole
|
||||||
|
unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove
|
||||||
|
logger DebugLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
route := request.Method + " " + request.URL.Path
|
||||||
|
roles := h.routeToRoles[route]
|
||||||
|
if len(roles) == 0 {
|
||||||
|
h.logger.Debugf("no authentication role defined for route %s", route)
|
||||||
|
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responseHeader := make(http.Header, 0)
|
||||||
|
for _, role := range roles {
|
||||||
|
if !role.checker.isAuthorized(responseHeader, request) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove
|
||||||
|
|
||||||
|
h.logger.Debugf("access to route %s authorized for role %s", route, role.name)
|
||||||
|
h.childHandler.ServeHTTP(writer, request)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush out response headers if all roles failed to authenticate
|
||||||
|
for headerKey, headerValues := range responseHeader {
|
||||||
|
for _, headerValue := range headerValues {
|
||||||
|
writer.Header().Add(headerKey, headerValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allRoleNames := make([]string, len(roles))
|
||||||
|
for i, role := range roles {
|
||||||
|
allRoleNames[i] = role.name
|
||||||
|
}
|
||||||
|
h.logger.Debugf("access to route %s unauthorized after checking for roles %s",
|
||||||
|
route, andStrings(allRoleNames))
|
||||||
|
http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) {
|
||||||
|
// TODO v3.41.0 remove
|
||||||
|
if role.name != "public" {
|
||||||
|
// custom role name, allow none authentication to be specified
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, isNoneChecker := role.checker.(*noneMethod)
|
||||||
|
if !isNoneChecker {
|
||||||
|
// not the none authentication method
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, isUnprotectedByDefault := h.unprotectedRoutes[route]
|
||||||
|
if !isUnprotectedByDefault {
|
||||||
|
// route is not unprotected by default, so this is a user decision
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Warnf("route %s is unprotected by default, "+
|
||||||
|
"please set up authentication following the documentation at "+
|
||||||
|
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||||
|
"since this will become no longer publicly accessible after release v3.40.",
|
||||||
|
route)
|
||||||
|
}
|
||||||
124
internal/server/middlewares/auth/middleware_test.go
Normal file
124
internal/server/middlewares/auth/middleware_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_authHandler_ServeHTTP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
settings Settings
|
||||||
|
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
|
||||||
|
requestMethod string
|
||||||
|
requestPath string
|
||||||
|
statusCode int
|
||||||
|
responseBody string
|
||||||
|
}{
|
||||||
|
"route_has_no_role": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
logger := NewMockDebugLogger(ctrl)
|
||||||
|
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/b",
|
||||||
|
statusCode: http.StatusUnauthorized,
|
||||||
|
responseBody: "Unauthorized\n",
|
||||||
|
},
|
||||||
|
"authorized_unprotected_by_default": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
logger := NewMockDebugLogger(ctrl)
|
||||||
|
logger.EXPECT().Warnf("route %s is unprotected by default, "+
|
||||||
|
"please set up authentication following the documentation at "+
|
||||||
|
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
||||||
|
"since this will become no longer publicly accessible after release v3.40.",
|
||||||
|
"GET /v1/vpn/status")
|
||||||
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||||
|
"GET /v1/vpn/status", "public")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/v1/vpn/status",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
"authorized_none": {
|
||||||
|
settings: Settings{
|
||||||
|
Roles: []Role{
|
||||||
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
logger := NewMockDebugLogger(ctrl)
|
||||||
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
||||||
|
"GET /a", "role1")
|
||||||
|
return logger
|
||||||
|
},
|
||||||
|
requestMethod: http.MethodGet,
|
||||||
|
requestPath: "/a",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
|
var debugLogger DebugLogger
|
||||||
|
if testCase.makeLogger != nil {
|
||||||
|
debugLogger = testCase.makeLogger(ctrl)
|
||||||
|
}
|
||||||
|
middleware, err := New(testCase.settings, debugLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
handler := middleware(childHandler)
|
||||||
|
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
|
||||||
|
client := server.Client()
|
||||||
|
|
||||||
|
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
request, err := http.NewRequestWithContext(context.Background(),
|
||||||
|
testCase.requestMethod, requestURL, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
response, err := client.Do(request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = response.Body.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.statusCode, response.StatusCode)
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testCase.responseBody, string(body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
3
internal/server/middlewares/auth/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger
|
||||||
68
internal/server/middlewares/auth/mocks_test.go
Normal file
68
internal/server/middlewares/auth/mocks_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger)
|
||||||
|
|
||||||
|
// Package auth is a generated GoMock package.
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockDebugLogger is a mock of DebugLogger interface.
|
||||||
|
type MockDebugLogger struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockDebugLoggerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger.
|
||||||
|
type MockDebugLoggerMockRecorder struct {
|
||||||
|
mock *MockDebugLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockDebugLogger creates a new mock instance.
|
||||||
|
func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger {
|
||||||
|
mock := &MockDebugLogger{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockDebugLoggerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf mocks base method.
|
||||||
|
func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Debugf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf indicates an expected call of Debugf.
|
||||||
|
func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf mocks base method.
|
||||||
|
func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []interface{}{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
m.ctrl.Call(m, "Warnf", varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf indicates an expected call of Warnf.
|
||||||
|
func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]interface{}{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...)
|
||||||
|
}
|
||||||
20
internal/server/middlewares/auth/none.go
Normal file
20
internal/server/middlewares/auth/none.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type noneMethod struct{}
|
||||||
|
|
||||||
|
func newNoneMethod() *noneMethod {
|
||||||
|
return &noneMethod{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal returns true if another auth checker is equal.
|
||||||
|
// This is used to deduplicate checkers for a particular route.
|
||||||
|
func (n *noneMethod) equal(other authorizationChecker) bool {
|
||||||
|
_, ok := other.(*noneMethod)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
131
internal/server/middlewares/auth/settings.go
Normal file
131
internal/server/middlewares/auth/settings.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/qdm12/gosettings"
|
||||||
|
"github.com/qdm12/gosettings/validate"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Settings struct {
|
||||||
|
// Roles is a list of roles with their associated authentication
|
||||||
|
// and routes.
|
||||||
|
Roles []Role
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Settings) SetDefaults() {
|
||||||
|
s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty
|
||||||
|
Name: "public",
|
||||||
|
Auth: "none",
|
||||||
|
Routes: []string{
|
||||||
|
http.MethodGet + " /openvpn/actions/restart",
|
||||||
|
http.MethodGet + " /unbound/actions/restart",
|
||||||
|
http.MethodGet + " /updater/restart",
|
||||||
|
http.MethodGet + " /v1/version",
|
||||||
|
http.MethodGet + " /v1/vpn/status",
|
||||||
|
http.MethodPut + " /v1/vpn/status",
|
||||||
|
http.MethodGet + " /v1/openvpn/status",
|
||||||
|
http.MethodPut + " /v1/openvpn/status",
|
||||||
|
http.MethodGet + " /v1/openvpn/portforwarded",
|
||||||
|
http.MethodGet + " /v1/dns/status",
|
||||||
|
http.MethodPut + " /v1/dns/status",
|
||||||
|
http.MethodGet + " /v1/updater/status",
|
||||||
|
http.MethodPut + " /v1/updater/status",
|
||||||
|
http.MethodGet + " /v1/publicip/ip",
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Settings) Validate() (err error) {
|
||||||
|
for i, role := range s.Roles {
|
||||||
|
err = role.validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("role %s (%d of %d): %w",
|
||||||
|
role.Name, i+1, len(s.Roles), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthNone = "none"
|
||||||
|
AuthAPIKey = "apikey"
|
||||||
|
AuthBasic = "basic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Role contains the role name, authentication method name and
|
||||||
|
// routes that the role can access.
|
||||||
|
type Role struct {
|
||||||
|
// Name is the role name and is only used for documentation
|
||||||
|
// and in the authentication middleware debug logs.
|
||||||
|
Name string
|
||||||
|
// Auth is the authentication method to use, which can be 'none' or 'apikey'.
|
||||||
|
Auth string
|
||||||
|
// APIKey is the API key to use when using the 'apikey' authentication.
|
||||||
|
APIKey string
|
||||||
|
// Username for HTTP Basic authentication method.
|
||||||
|
Username string
|
||||||
|
// Password for HTTP Basic authentication method.
|
||||||
|
Password string
|
||||||
|
// Routes is a list of routes that the role can access in the format
|
||||||
|
// "HTTP_METHOD PATH", for example "GET /v1/vpn/status"
|
||||||
|
Routes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMethodNotSupported = errors.New("authentication method not supported")
|
||||||
|
ErrAPIKeyEmpty = errors.New("api key is empty")
|
||||||
|
ErrBasicUsernameEmpty = errors.New("username is empty")
|
||||||
|
ErrBasicPasswordEmpty = errors.New("password is empty")
|
||||||
|
ErrRouteNotSupported = errors.New("route not supported by the control server")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r Role) validate() (err error) {
|
||||||
|
err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case r.Auth == AuthAPIKey && r.APIKey == "":
|
||||||
|
return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty)
|
||||||
|
case r.Auth == AuthBasic && r.Username == "":
|
||||||
|
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty)
|
||||||
|
case r.Auth == AuthBasic && r.Password == "":
|
||||||
|
return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, route := range r.Routes {
|
||||||
|
_, ok := validRoutes[route]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("route %d of %d: %w: %s",
|
||||||
|
i+1, len(r.Routes), ErrRouteNotSupported, route)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING: do not mutate programmatically.
|
||||||
|
var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals
|
||||||
|
http.MethodGet + " /openvpn/actions/restart": {},
|
||||||
|
http.MethodGet + " /unbound/actions/restart": {},
|
||||||
|
http.MethodGet + " /updater/restart": {},
|
||||||
|
http.MethodGet + " /v1/version": {},
|
||||||
|
http.MethodGet + " /v1/vpn/status": {},
|
||||||
|
http.MethodPut + " /v1/vpn/status": {},
|
||||||
|
http.MethodGet + " /v1/vpn/settings": {},
|
||||||
|
http.MethodPut + " /v1/vpn/settings": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/status": {},
|
||||||
|
http.MethodPut + " /v1/openvpn/status": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/portforwarded": {},
|
||||||
|
http.MethodGet + " /v1/openvpn/settings": {},
|
||||||
|
http.MethodGet + " /v1/dns/status": {},
|
||||||
|
http.MethodPut + " /v1/dns/status": {},
|
||||||
|
http.MethodGet + " /v1/updater/status": {},
|
||||||
|
http.MethodPut + " /v1/updater/status": {},
|
||||||
|
http.MethodGet + " /v1/publicip/ip": {},
|
||||||
|
}
|
||||||
5
internal/server/middlewares/log/interfaces.go
Normal file
5
internal/server/middlewares/log/interfaces.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package log
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Info(message string)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -7,18 +7,21 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware {
|
func New(logger Logger, enabled bool) (
|
||||||
|
middleware func(http.Handler) http.Handler) {
|
||||||
|
return func(handler http.Handler) http.Handler {
|
||||||
return &logMiddleware{
|
return &logMiddleware{
|
||||||
childHandler: childHandler,
|
childHandler: handler,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
timeNow: time.Now,
|
timeNow: time.Now,
|
||||||
enabled: enabled,
|
enabled: enabled,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type logMiddleware struct {
|
type logMiddleware struct {
|
||||||
childHandler http.Handler
|
childHandler http.Handler
|
||||||
logger infoer
|
logger Logger
|
||||||
timeNow func() time.Time
|
timeNow func() time.Time
|
||||||
enabled bool
|
enabled bool
|
||||||
enabledMu sync.RWMutex
|
enabledMu sync.RWMutex
|
||||||
@@ -39,7 +42,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
r.RemoteAddr + " in " + duration.String())
|
r.RemoteAddr + " in " + duration.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *logMiddleware) setEnabled(enabled bool) {
|
func (m *logMiddleware) SetEnabled(enabled bool) {
|
||||||
m.enabledMu.Lock()
|
m.enabledMu.Lock()
|
||||||
defer m.enabledMu.Unlock()
|
defer m.enabledMu.Unlock()
|
||||||
m.enabled = enabled
|
m.enabled = enabled
|
||||||
@@ -6,17 +6,31 @@ import (
|
|||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/httpserver"
|
"github.com/qdm12/gluetun/internal/httpserver"
|
||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
|
"github.com/qdm12/gluetun/internal/server/middlewares/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
func New(ctx context.Context, address string, logEnabled bool, logger Logger,
|
||||||
buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
|
||||||
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
pfGetter PortForwardedGetter, unboundLooper DNSLoop,
|
||||||
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
|
||||||
ipv6Supported bool) (
|
ipv6Supported bool) (
|
||||||
server *httpserver.Server, err error) {
|
server *httpserver.Server, err error) {
|
||||||
handler := newHandler(ctx, logger, logEnabled, buildInfo,
|
authSettings, err := auth.Read(authConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading auth settings: %w", err)
|
||||||
|
}
|
||||||
|
authSettings.SetDefaults()
|
||||||
|
err = authSettings.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validating auth settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
|
||||||
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
|
||||||
storage, ipv6Supported)
|
storage, ipv6Supported)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating handler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
httpServerSettings := httpserver.Settings{
|
httpServerSettings := httpserver.Settings{
|
||||||
Address: address,
|
Address: address,
|
||||||
|
|||||||
Reference in New Issue
Block a user