diff --git a/pkg/gofr/context.go b/pkg/gofr/context.go index 9858a1673..5788fd597 100644 --- a/pkg/gofr/context.go +++ b/pkg/gofr/context.go @@ -3,12 +3,14 @@ package gofr import ( "context" + "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" "gofr.dev/pkg/gofr/container" + "gofr.dev/pkg/gofr/http/middleware" ) type Context struct { @@ -28,6 +30,12 @@ type Context struct { responder Responder } +type AuthInfo interface { + GetClaims() jwt.MapClaims + GetUsername() string + GetAPIKey() string +} + /* Trace returns an open telemetry span. We have to always close the span after corresponding work is done. Usages: @@ -75,6 +83,49 @@ func (c *Context) WriteMessageToSocket(data any) error { return conn.WriteMessage(websocket.TextMessage, message) } +type authInfo struct { + claims jwt.MapClaims + username string + apiKey string +} + +// GetAuthInfo is a method on context, to access different methods to retrieve authentication info. +// +// GetAuthInfo().GetClaims() : retrieves the jwt claims. +// GetAuthInfo().GetUsername() : retrieves the username while basic authentication. +// GetAuthInfo().GetAPIKey() : retrieves the APIKey being used for authentication. +func (c *Context) GetAuthInfo() AuthInfo { + claims, _ := c.Request.Context().Value(middleware.JWTClaim).(jwt.MapClaims) + + APIKey, _ := c.Request.Context().Value(middleware.APIKey).(string) + + username, _ := c.Request.Context().Value(middleware.Username).(string) + + return &authInfo{ + claims: claims, + username: username, + apiKey: APIKey, + } +} + +// GetClaims returns a response of jwt.MapClaims type when OAuth is enabled. +// It returns nil if called, when OAuth is not enabled. +func (a *authInfo) GetClaims() jwt.MapClaims { + return a.claims +} + +// GetUsername returns the username when basic auth is enabled. +// It returns an empty string if called, when basic auth is not enabled. +func (a *authInfo) GetUsername() string { + return a.username +} + +// GetAPIKey returns the APIKey when APIKey auth is enabled. +// It returns an empty strung if called, when APIKey auth is not enabled. +func (a *authInfo) GetAPIKey() string { + return a.apiKey +} + // func (c *Context) reset(w Responder, r Request) { // c.Request = r // c.responder = w diff --git a/pkg/gofr/context_test.go b/pkg/gofr/context_test.go index 599791bf4..9aa23958e 100644 --- a/pkg/gofr/context_test.go +++ b/pkg/gofr/context_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,6 +17,7 @@ import ( "gofr.dev/pkg/gofr/config" "gofr.dev/pkg/gofr/container" gofrHTTP "gofr.dev/pkg/gofr/http" + "gofr.dev/pkg/gofr/http/middleware" "gofr.dev/pkg/gofr/logging" "gofr.dev/pkg/gofr/version" ) @@ -109,3 +111,71 @@ func TestContext_WriteMessageToSocket(t *testing.T) { expectedResponse := "Hello! GoFr" assert.Equal(t, expectedResponse, string(message)) } + +func TestGetAuthInfo_BasicAuth(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + + ctx := context.WithValue(req.Context(), middleware.Username, "validUser") + *req = *req.Clone(ctx) + + mockContainer, _ := container.NewMockContainer(t) + gofrRq := gofrHTTP.NewRequest(req) + + c := &Context{ + Context: ctx, + Request: gofrRq, + Container: mockContainer, + } + + res := c.GetAuthInfo().GetUsername() + + assert.Equal(t, "validUser", res) +} + +func TestGetAuthInfo_ApiKey(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + + ctx := context.WithValue(req.Context(), middleware.APIKey, "9221e451-451f-4cd6-a23d-2b2d3adea9cf") + + *req = *req.Clone(ctx) + gofrRq := gofrHTTP.NewRequest(req) + + mockContainer, _ := container.NewMockContainer(t) + + c := &Context{ + Context: ctx, + Request: gofrRq, + Container: mockContainer, + } + + res := c.GetAuthInfo().GetAPIKey() + + assert.Equal(t, "9221e451-451f-4cd6-a23d-2b2d3adea9cf", res) +} + +func TestGetAuthInfo_JWTClaims(t *testing.T) { + claims := jwt.MapClaims{ + "sub": "1234567890", + "name": "John Doe", + "admin": true, + } + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + + ctx := context.WithValue(req.Context(), middleware.JWTClaim, claims) + + *req = *req.Clone(ctx) + gofrRq := gofrHTTP.NewRequest(req) + + mockContainer, _ := container.NewMockContainer(t) + + c := &Context{ + Context: ctx, + Request: gofrRq, + Container: mockContainer, + } + + res := c.GetAuthInfo().GetClaims() + + assert.Equal(t, claims, res) +} diff --git a/pkg/gofr/http/middleware/apikey_auth.go b/pkg/gofr/http/middleware/apikey_auth.go index cf6961b90..5ed1210fb 100644 --- a/pkg/gofr/http/middleware/apikey_auth.go +++ b/pkg/gofr/http/middleware/apikey_auth.go @@ -3,6 +3,7 @@ package middleware import ( + "context" "net/http" "gofr.dev/pkg/gofr/container" @@ -15,6 +16,8 @@ type APIKeyAuthProvider struct { Container *container.Container } +const APIKey authMethod = 2 + // APIKeyAuthMiddleware creates a middleware function that enforces API key authentication based on the provided API // keys or a validation function. func APIKeyAuthMiddleware(a APIKeyAuthProvider, apiKeys ...string) func(handler http.Handler) http.Handler { @@ -36,6 +39,9 @@ func APIKeyAuthMiddleware(a APIKeyAuthProvider, apiKeys ...string) func(handler return } + ctx := context.WithValue(r.Context(), APIKey, authKey) + *r = *r.Clone(ctx) + handler.ServeHTTP(w, r) }) } diff --git a/pkg/gofr/http/middleware/basic_auth.go b/pkg/gofr/http/middleware/basic_auth.go index 37390acb6..0f389f016 100644 --- a/pkg/gofr/http/middleware/basic_auth.go +++ b/pkg/gofr/http/middleware/basic_auth.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "encoding/base64" "net/http" "strings" @@ -16,6 +17,8 @@ type BasicAuthProvider struct { Container *container.Container } +const Username authMethod = 1 + // BasicAuthMiddleware creates a middleware function that enforces basic authentication using the provided BasicAuthProvider. func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.Handler) http.Handler { return func(handler http.Handler) http.Handler { @@ -54,6 +57,9 @@ func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http. return } + ctx := context.WithValue(r.Context(), Username, username) + *r = *r.Clone(ctx) + handler.ServeHTTP(w, r) }) }