124 lines
3.4 KiB
Go
124 lines
3.4 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func Test_authHandler_ServeHTTP(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := map[string]struct {
|
|
settings Settings
|
|
makeLogger func(ctrl *gomock.Controller) *MockDebugLogger
|
|
requestMethod string
|
|
requestPath string
|
|
statusCode int
|
|
responseBody string
|
|
}{
|
|
"route_has_no_role": {
|
|
settings: Settings{
|
|
Roles: []Role{
|
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
|
},
|
|
},
|
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
|
logger := NewMockDebugLogger(ctrl)
|
|
logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b")
|
|
return logger
|
|
},
|
|
requestMethod: http.MethodGet,
|
|
requestPath: "/b",
|
|
statusCode: http.StatusUnauthorized,
|
|
responseBody: "Unauthorized\n",
|
|
},
|
|
"authorized_unprotected_by_default": {
|
|
settings: Settings{
|
|
Roles: []Role{
|
|
{Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}},
|
|
},
|
|
},
|
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
|
logger := NewMockDebugLogger(ctrl)
|
|
logger.EXPECT().Warnf("route %s is unprotected by default, "+
|
|
"please set up authentication following the documentation at "+
|
|
"https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+
|
|
"since this will become no longer publicly accessible after release v3.40.",
|
|
"GET /v1/vpn/status")
|
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
|
"GET /v1/vpn/status", "public")
|
|
return logger
|
|
},
|
|
requestMethod: http.MethodGet,
|
|
requestPath: "/v1/vpn/status",
|
|
statusCode: http.StatusOK,
|
|
},
|
|
"authorized_none": {
|
|
settings: Settings{
|
|
Roles: []Role{
|
|
{Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}},
|
|
},
|
|
},
|
|
makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger {
|
|
logger := NewMockDebugLogger(ctrl)
|
|
logger.EXPECT().Debugf("access to route %s authorized for role %s",
|
|
"GET /a", "role1")
|
|
return logger
|
|
},
|
|
requestMethod: http.MethodGet,
|
|
requestPath: "/a",
|
|
statusCode: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for name, testCase := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctrl := gomock.NewController(t)
|
|
|
|
var debugLogger DebugLogger
|
|
if testCase.makeLogger != nil {
|
|
debugLogger = testCase.makeLogger(ctrl)
|
|
}
|
|
middleware, err := New(testCase.settings, debugLogger)
|
|
require.NoError(t, err)
|
|
|
|
childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
handler := middleware(childHandler)
|
|
|
|
server := httptest.NewServer(handler)
|
|
t.Cleanup(server.Close)
|
|
|
|
client := server.Client()
|
|
|
|
requestURL, err := url.JoinPath(server.URL, testCase.requestPath)
|
|
require.NoError(t, err)
|
|
request, err := http.NewRequestWithContext(context.Background(),
|
|
testCase.requestMethod, requestURL, nil)
|
|
require.NoError(t, err)
|
|
|
|
response, err := client.Do(request)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
err = response.Body.Close()
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
assert.Equal(t, testCase.statusCode, response.StatusCode)
|
|
body, err := io.ReadAll(response.Body)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, testCase.responseBody, string(body))
|
|
})
|
|
}
|
|
}
|