Skip to content

Commit

Permalink
Fetch Auth Info from context when authentication is enabled (#1161)
Browse files Browse the repository at this point in the history
  • Loading branch information
coolwednesday authored Nov 4, 2024
1 parent 36092ca commit 366d3f5
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 0 deletions.
51 changes: 51 additions & 0 deletions pkg/gofr/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package gofr
import (
"context"

"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/websocket"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"

"gofr.dev/pkg/gofr/container"
"gofr.dev/pkg/gofr/http/middleware"
)

type Context struct {
Expand All @@ -28,6 +30,12 @@ type Context struct {
responder Responder
}

type AuthInfo interface {
GetClaims() jwt.MapClaims
GetUsername() string
GetAPIKey() string
}

/*
Trace returns an open telemetry span. We have to always close the span after corresponding work is done. Usages:
Expand Down Expand Up @@ -75,6 +83,49 @@ func (c *Context) WriteMessageToSocket(data any) error {
return conn.WriteMessage(websocket.TextMessage, message)
}

type authInfo struct {
claims jwt.MapClaims
username string
apiKey string
}

// GetAuthInfo is a method on context, to access different methods to retrieve authentication info.
//
// GetAuthInfo().GetClaims() : retrieves the jwt claims.
// GetAuthInfo().GetUsername() : retrieves the username while basic authentication.
// GetAuthInfo().GetAPIKey() : retrieves the APIKey being used for authentication.
func (c *Context) GetAuthInfo() AuthInfo {
claims, _ := c.Request.Context().Value(middleware.JWTClaim).(jwt.MapClaims)

APIKey, _ := c.Request.Context().Value(middleware.APIKey).(string)

username, _ := c.Request.Context().Value(middleware.Username).(string)

return &authInfo{
claims: claims,
username: username,
apiKey: APIKey,
}
}

// GetClaims returns a response of jwt.MapClaims type when OAuth is enabled.
// It returns nil if called, when OAuth is not enabled.
func (a *authInfo) GetClaims() jwt.MapClaims {
return a.claims
}

// GetUsername returns the username when basic auth is enabled.
// It returns an empty string if called, when basic auth is not enabled.
func (a *authInfo) GetUsername() string {
return a.username
}

// GetAPIKey returns the APIKey when APIKey auth is enabled.
// It returns an empty strung if called, when APIKey auth is not enabled.
func (a *authInfo) GetAPIKey() string {
return a.apiKey
}

// func (c *Context) reset(w Responder, r Request) {
// c.Request = r
// c.responder = w
Expand Down
70 changes: 70 additions & 0 deletions pkg/gofr/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -16,6 +17,7 @@ import (
"gofr.dev/pkg/gofr/config"
"gofr.dev/pkg/gofr/container"
gofrHTTP "gofr.dev/pkg/gofr/http"
"gofr.dev/pkg/gofr/http/middleware"
"gofr.dev/pkg/gofr/logging"
"gofr.dev/pkg/gofr/version"
)
Expand Down Expand Up @@ -109,3 +111,71 @@ func TestContext_WriteMessageToSocket(t *testing.T) {
expectedResponse := "Hello! GoFr"
assert.Equal(t, expectedResponse, string(message))
}

func TestGetAuthInfo_BasicAuth(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)

ctx := context.WithValue(req.Context(), middleware.Username, "validUser")
*req = *req.Clone(ctx)

mockContainer, _ := container.NewMockContainer(t)
gofrRq := gofrHTTP.NewRequest(req)

c := &Context{
Context: ctx,
Request: gofrRq,
Container: mockContainer,
}

res := c.GetAuthInfo().GetUsername()

assert.Equal(t, "validUser", res)
}

func TestGetAuthInfo_ApiKey(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)

ctx := context.WithValue(req.Context(), middleware.APIKey, "9221e451-451f-4cd6-a23d-2b2d3adea9cf")

*req = *req.Clone(ctx)
gofrRq := gofrHTTP.NewRequest(req)

mockContainer, _ := container.NewMockContainer(t)

c := &Context{
Context: ctx,
Request: gofrRq,
Container: mockContainer,
}

res := c.GetAuthInfo().GetAPIKey()

assert.Equal(t, "9221e451-451f-4cd6-a23d-2b2d3adea9cf", res)
}

func TestGetAuthInfo_JWTClaims(t *testing.T) {
claims := jwt.MapClaims{
"sub": "1234567890",
"name": "John Doe",
"admin": true,
}

req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)

ctx := context.WithValue(req.Context(), middleware.JWTClaim, claims)

*req = *req.Clone(ctx)
gofrRq := gofrHTTP.NewRequest(req)

mockContainer, _ := container.NewMockContainer(t)

c := &Context{
Context: ctx,
Request: gofrRq,
Container: mockContainer,
}

res := c.GetAuthInfo().GetClaims()

assert.Equal(t, claims, res)
}
6 changes: 6 additions & 0 deletions pkg/gofr/http/middleware/apikey_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package middleware

import (
"context"
"net/http"

"gofr.dev/pkg/gofr/container"
Expand All @@ -15,6 +16,8 @@ type APIKeyAuthProvider struct {
Container *container.Container
}

const APIKey authMethod = 2

// APIKeyAuthMiddleware creates a middleware function that enforces API key authentication based on the provided API
// keys or a validation function.
func APIKeyAuthMiddleware(a APIKeyAuthProvider, apiKeys ...string) func(handler http.Handler) http.Handler {
Expand All @@ -36,6 +39,9 @@ func APIKeyAuthMiddleware(a APIKeyAuthProvider, apiKeys ...string) func(handler
return
}

ctx := context.WithValue(r.Context(), APIKey, authKey)
*r = *r.Clone(ctx)

handler.ServeHTTP(w, r)
})
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/gofr/http/middleware/basic_auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"context"
"encoding/base64"
"net/http"
"strings"
Expand All @@ -16,6 +17,8 @@ type BasicAuthProvider struct {
Container *container.Container
}

const Username authMethod = 1

// BasicAuthMiddleware creates a middleware function that enforces basic authentication using the provided BasicAuthProvider.
func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.Handler) http.Handler {
return func(handler http.Handler) http.Handler {
Expand Down Expand Up @@ -54,6 +57,9 @@ func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.
return
}

ctx := context.WithValue(r.Context(), Username, username)
*r = *r.Clone(ctx)

handler.ServeHTTP(w, r)
})
}
Expand Down

0 comments on commit 366d3f5

Please sign in to comment.