Skip to content

Commit

Permalink
Rate limiting (per operation, per auth/ip) (#2640)
Browse files Browse the repository at this point in the history
* 2347: Admin/Runtime: Rate limiting (per operation, per auth/ip)

* Provided descriptions for key structs

* Refactored according to PR comments

* Limiter interface

* Fixed formatting

* Fixed tests

* Updated func declarations according the lint

* Changed LimiterHTTPHandler visibility

* Made limit checks more explicit by materializing them in respective packages

* Added a request to Metadata so that it is available in a handler

* Fixed linter errors

* Fixed linter errors

* added zap.Error(err)

* Reverted auto-IDE change

* Added a comment

* Simplified by using grpc_auth.UnaryServerInterceptor

* Tripled rate limits before we have real numbers measured by a usage tracking

* Reverted auto change

* check if claims == nil

---------

Co-authored-by: e.sevastyanov <[email protected]>
Co-authored-by: Benjamin Egelund-Müller <[email protected]>
  • Loading branch information
3 people committed Jul 3, 2023
1 parent 3e8090a commit 4a9b8e5
Show file tree
Hide file tree
Showing 19 changed files with 359 additions and 38 deletions.
37 changes: 27 additions & 10 deletions admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package auth
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/runtime/pkg/middleware"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.uber.org/zap"
"golang.org/x/oauth2"
Expand All @@ -25,18 +28,32 @@ const (
// RegisterEndpoints adds HTTP endpoints for auth.
// The mux must be served on the ExternalURL of the Authenticator since the logic in these handlers relies on knowing the full external URIs.
// Note that these are not gRPC handlers, just regular HTTP endpoints that we mount on the gRPC-gateway mux.
func (a *Authenticator) RegisterEndpoints(mux *http.ServeMux) {
func (a *Authenticator) RegisterEndpoints(mux *http.ServeMux, limiter ratelimit.Limiter) {
// checkLimit needs access to limiter
checkLimit := func(route string, req *http.Request) error {
claims := GetClaims(req.Context())
if claims == nil || claims.OwnerType() == OwnerTypeAnon {
limitKey := ratelimit.AnonLimitKey(route, observability.HTTPPeer(req))
if err := limiter.Limit(req.Context(), limitKey, ratelimit.Sensitive); err != nil {
if errors.As(err, &ratelimit.QuotaExceededError{}) {
return middleware.NewHTTPError(http.StatusTooManyRequests, err.Error())
}
return err
}
}
return nil
}
// TODO: Add helper utils to clean this up
inner := http.NewServeMux()
inner.Handle("/auth/signup", otelhttp.WithRouteTag("/auth/signup", http.HandlerFunc(a.authSignup)))
inner.Handle("/auth/login", otelhttp.WithRouteTag("/auth/login", http.HandlerFunc(a.authLogin)))
inner.Handle("/auth/callback", otelhttp.WithRouteTag("/auth/callback", http.HandlerFunc(a.authLoginCallback)))
inner.Handle("/auth/with-token", otelhttp.WithRouteTag("/auth/with-token", http.HandlerFunc(a.authWithToken)))
inner.Handle("/auth/logout", otelhttp.WithRouteTag("/auth/logout", http.HandlerFunc(a.authLogout)))
inner.Handle("/auth/logout/callback", otelhttp.WithRouteTag("/auth/logout/callback", http.HandlerFunc(a.authLogoutCallback)))
inner.Handle("/auth/oauth/device_authorization", otelhttp.WithRouteTag("/auth/oauth/device_authorization", http.HandlerFunc(a.handleDeviceCodeRequest)))
inner.Handle("/auth/oauth/device", otelhttp.WithRouteTag("/auth/oauth/device", a.HTTPMiddleware(http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
inner.Handle("/auth/oauth/token", otelhttp.WithRouteTag("/auth/oauth/token", http.HandlerFunc(a.getAccessToken)))
inner.Handle("/auth/signup", otelhttp.WithRouteTag("/auth/signup", middleware.RequestHTTPHandler("/auth/signup", checkLimit, http.HandlerFunc(a.authSignup))))
inner.Handle("/auth/login", otelhttp.WithRouteTag("/auth/login", middleware.RequestHTTPHandler("/auth/login", checkLimit, http.HandlerFunc(a.authLogin))))
inner.Handle("/auth/callback", otelhttp.WithRouteTag("/auth/callback", middleware.RequestHTTPHandler("/auth/callback", checkLimit, http.HandlerFunc(a.authLoginCallback))))
inner.Handle("/auth/with-token", otelhttp.WithRouteTag("/auth/with-token", middleware.RequestHTTPHandler("/auth/with-token", checkLimit, http.HandlerFunc(a.authWithToken))))
inner.Handle("/auth/logout", otelhttp.WithRouteTag("/auth/logout", middleware.RequestHTTPHandler("/auth/logout", checkLimit, http.HandlerFunc(a.authLogout))))
inner.Handle("/auth/logout/callback", otelhttp.WithRouteTag("/auth/logout/callback", middleware.RequestHTTPHandler("/auth/logout/callback", checkLimit, http.HandlerFunc(a.authLogoutCallback))))
inner.Handle("/auth/oauth/device_authorization", otelhttp.WithRouteTag("/auth/oauth/device_authorization", middleware.RequestHTTPHandler("/auth/oauth/device_authorization", checkLimit, http.HandlerFunc(a.handleDeviceCodeRequest))))
inner.Handle("/auth/oauth/device", otelhttp.WithRouteTag("/auth/oauth/device", a.HTTPMiddleware(middleware.RequestHTTPHandler("/auth/oauth/device", checkLimit, http.HandlerFunc(a.handleUserCodeConfirmation))))) // NOTE: Uses auth middleware
inner.Handle("/auth/oauth/token", otelhttp.WithRouteTag("/auth/oauth/token", middleware.RequestHTTPHandler("/auth/oauth/token", checkLimit, http.HandlerFunc(a.getAccessToken))))
mux.Handle("/auth/", observability.Middleware("admin", a.logger, inner))
}

Expand Down
31 changes: 26 additions & 5 deletions admin/server/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (
"github.com/rilldata/rill/admin/pkg/urlutil"
"github.com/rilldata/rill/admin/server/auth"
adminv1 "github.com/rilldata/rill/proto/gen/rill/admin/v1"
"github.com/rilldata/rill/runtime/pkg/middleware"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -142,11 +144,16 @@ func (s *Server) registerGithubEndpoints(mux *http.ServeMux) {
// TODO: Add helper utils to clean this up
inner := http.NewServeMux()
inner.Handle("/github/webhook", otelhttp.WithRouteTag("/github/webhook", http.HandlerFunc(s.githubWebhook)))
inner.Handle("/github/connect", otelhttp.WithRouteTag("/github/connect", s.authenticator.HTTPMiddleware(http.HandlerFunc(s.githubConnect))))
inner.Handle("/github/connect/callback", otelhttp.WithRouteTag("/github/connect/callback", s.authenticator.HTTPMiddleware(http.HandlerFunc(s.githubConnectCallback))))
inner.Handle("/github/auth/login", otelhttp.WithRouteTag("github/auth/login", s.authenticator.HTTPMiddleware(http.HandlerFunc(s.githubAuthLogin))))
inner.Handle("/github/auth/callback", otelhttp.WithRouteTag("github/auth/callback", s.authenticator.HTTPMiddleware(http.HandlerFunc(s.githubAuthCallback))))
inner.Handle("/github/post-auth-redirect", otelhttp.WithRouteTag("github/post-auth-redirect", s.authenticator.HTTPMiddleware(http.HandlerFunc(s.githubRepoStatus))))
inner.Handle("/github/connect", otelhttp.WithRouteTag("/github/connect", s.authenticator.HTTPMiddleware(
middleware.RequestHTTPHandler("/github/connect", s.checkGithubRateLimit, http.HandlerFunc(s.githubConnect)))))
inner.Handle("/github/connect/callback", otelhttp.WithRouteTag("/github/connect/callback", s.authenticator.HTTPMiddleware(
middleware.RequestHTTPHandler("/github/connect/callback", s.checkGithubRateLimit, http.HandlerFunc(s.githubConnectCallback)))))
inner.Handle("/github/auth/login", otelhttp.WithRouteTag("github/auth/login", s.authenticator.HTTPMiddleware(
middleware.RequestHTTPHandler("github/auth/login", s.checkGithubRateLimit, http.HandlerFunc(s.githubAuthLogin)))))
inner.Handle("/github/auth/callback", otelhttp.WithRouteTag("github/auth/callback", s.authenticator.HTTPMiddleware(
middleware.RequestHTTPHandler("github/auth/callback", s.checkGithubRateLimit, http.HandlerFunc(s.githubAuthCallback)))))
inner.Handle("/github/post-auth-redirect", otelhttp.WithRouteTag("github/post-auth-redirect", s.authenticator.HTTPMiddleware(
middleware.RequestHTTPHandler("github/post-auth-redirect", s.checkGithubRateLimit, http.HandlerFunc(s.githubRepoStatus)))))
mux.Handle("/github/", observability.Middleware("admin", s.logger, inner))
}

Expand Down Expand Up @@ -585,3 +592,17 @@ func (s *Server) redirectLogin(w http.ResponseWriter, r *http.Request) {

http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}

func (s *Server) checkGithubRateLimit(route string, req *http.Request) error {
claims := auth.GetClaims(req.Context())
if claims == nil || claims.OwnerType() == auth.OwnerTypeAnon {
limitKey := ratelimit.AnonLimitKey(route, observability.HTTPPeer(req))
if err := s.limiter.Limit(req.Context(), limitKey, ratelimit.Sensitive); err != nil {
if errors.As(err, &ratelimit.QuotaExceededError{}) {
return middleware.NewHTTPError(http.StatusTooManyRequests, err.Error())
}
return err
}
}
return nil
}
31 changes: 29 additions & 2 deletions admin/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/rilldata/rill/runtime/pkg/graceful"
"github.com/rilldata/rill/runtime/pkg/middleware"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
runtimeauth "github.com/rilldata/rill/runtime/server/auth"
"github.com/rs/cors"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
Expand Down Expand Up @@ -68,11 +69,12 @@ type Server struct {
authenticator *auth.Authenticator
issuer *runtimeauth.Issuer
urls *externalURLs
limiter ratelimit.Limiter
}

var _ adminv1.AdminServiceServer = (*Server)(nil)

func New(logger *zap.Logger, adm *admin.Service, issuer *runtimeauth.Issuer, opts *Options) (*Server, error) {
func New(logger *zap.Logger, adm *admin.Service, issuer *runtimeauth.Issuer, limiter ratelimit.Limiter, opts *Options) (*Server, error) {
externalURL, err := url.Parse(opts.ExternalURL)
if err != nil {
return nil, fmt.Errorf("failed to parse external URL: %w", err)
Expand Down Expand Up @@ -106,6 +108,7 @@ func New(logger *zap.Logger, adm *admin.Service, issuer *runtimeauth.Issuer, opt
authenticator: authenticator,
issuer: issuer,
urls: newURLRegistry(opts),
limiter: limiter,
}, nil
}

Expand All @@ -120,6 +123,7 @@ func (s *Server) ServeGRPC(ctx context.Context) error {
grpc_auth.StreamServerInterceptor(checkUserAgent),
grpc_validator.StreamServerInterceptor(),
s.authenticator.StreamServerInterceptor(),
grpc_auth.StreamServerInterceptor(s.checkRateLimit),
),
grpc.ChainUnaryInterceptor(
middleware.TimeoutUnaryServerInterceptor(timeoutSelector),
Expand All @@ -129,6 +133,7 @@ func (s *Server) ServeGRPC(ctx context.Context) error {
grpc_auth.UnaryServerInterceptor(checkUserAgent),
grpc_validator.UnaryServerInterceptor(),
s.authenticator.UnaryServerInterceptor(),
grpc_auth.UnaryServerInterceptor(s.checkRateLimit),
),
)

Expand Down Expand Up @@ -176,7 +181,7 @@ func (s *Server) HTTPHandler(ctx context.Context) (http.Handler, error) {
mux.Handle("/.well-known/jwks.json", s.issuer.WellKnownHandler())

// Add auth endpoints (not gRPC handlers, just regular endpoints on /auth/*)
s.authenticator.RegisterEndpoints(mux)
s.authenticator.RegisterEndpoints(mux, s.limiter)

// Add Github-related endpoints (not gRPC handlers, just regular endpoints on /github/*)
s.registerGithubEndpoints(mux)
Expand Down Expand Up @@ -336,3 +341,25 @@ func newURLRegistry(opts *Options) *externalURLs {
authLogin: urlutil.MustJoinURL(opts.ExternalURL, "/auth/login"),
}
}

func (s *Server) checkRateLimit(ctx context.Context) (context.Context, error) {
var limitKey string
method, ok := grpc.Method(ctx)
if !ok {
return ctx, fmt.Errorf("server context does not have a method")
}
if auth.GetClaims(ctx).OwnerType() == auth.OwnerTypeAnon {
limitKey = ratelimit.AnonLimitKey(method, observability.GrpcPeer(ctx))
} else {
limitKey = ratelimit.AuthLimitKey(method, auth.GetClaims(ctx).OwnerID())
}

if err := s.limiter.Limit(ctx, limitKey, ratelimit.Default); err != nil {
if errors.As(err, &ratelimit.QuotaExceededError{}) {
return ctx, status.Errorf(codes.ResourceExhausted, err.Error())
}
return ctx, err
}

return ctx, nil
}
15 changes: 14 additions & 1 deletion cli/cmd/admin/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (

"github.com/joho/godotenv"
"github.com/kelseyhightower/envconfig"
"github.com/redis/go-redis/v9"
"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/email"
"github.com/rilldata/rill/admin/server"
"github.com/rilldata/rill/admin/worker"
"github.com/rilldata/rill/cli/pkg/config"
"github.com/rilldata/rill/runtime/pkg/graceful"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
"github.com/rilldata/rill/runtime/server/auth"
"github.com/spf13/cobra"
"go.uber.org/zap"
Expand Down Expand Up @@ -61,6 +63,7 @@ type Config struct {
EmailSenderEmail string `split_words:"true"`
EmailSenderName string `split_words:"true"`
EmailBCC string `split_words:"true"`
RedisURL string `default:"" split_words:"true"`
}

// StartCmd starts an admin server. It only allows configuration using environment variables.
Expand Down Expand Up @@ -176,7 +179,17 @@ func StartCmd(cliCfg *config.Config) *cobra.Command {

// Init and run server
if runServer {
srv, err := server.New(logger, adm, issuer, &server.Options{
var limiter ratelimit.Limiter
if conf.RedisURL == "" {
limiter = ratelimit.NewNoop()
} else {
opts, err := redis.ParseURL(conf.RedisURL)
if err != nil {
logger.Fatal("failed to parse redis url", zap.Error(err))
}
limiter = ratelimit.NewRedis(redis.NewClient(opts))
}
srv, err := server.New(logger, adm, issuer, limiter, &server.Options{
HTTPPort: conf.HTTPPort,
GRPCPort: conf.GRPCPort,
ExternalURL: conf.ExternalURL,
Expand Down
17 changes: 16 additions & 1 deletion cli/cmd/runtime/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (

"github.com/joho/godotenv"
"github.com/kelseyhightower/envconfig"
"github.com/redis/go-redis/v9"
"github.com/rilldata/rill/cli/pkg/config"
"github.com/rilldata/rill/runtime"
"github.com/rilldata/rill/runtime/pkg/graceful"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
"github.com/rilldata/rill/runtime/server"
"github.com/spf13/cobra"
"go.uber.org/zap"
Expand Down Expand Up @@ -52,6 +54,8 @@ type Config struct {
// AllowHostAccess controls whether instance can use host credentials and
// local_file sources can access directory outside repo
AllowHostAccess bool `default:"false" split_words:"true"`
// Redis server address host:port
RedisURL string `default:"" split_words:"true"`
}

// StartCmd starts a stand-alone runtime server. It only allows configuration using environment variables.
Expand Down Expand Up @@ -119,6 +123,17 @@ func StartCmd(cliCfg *config.Config) *cobra.Command {
// Create ctx that cancels on termination signals
ctx := graceful.WithCancelOnTerminate(context.Background())

var limiter ratelimit.Limiter
if conf.RedisURL == "" {
limiter = ratelimit.NewNoop()
} else {
opts, err := redis.ParseURL(conf.RedisURL)
if err != nil {
logger.Fatal("failed to parse redis url", zap.Error(err))
}
limiter = ratelimit.NewRedis(redis.NewClient(opts))
}

// Init server
srvOpts := &server.Options{
HTTPPort: conf.HTTPPort,
Expand All @@ -129,7 +144,7 @@ func StartCmd(cliCfg *config.Config) *cobra.Command {
AuthIssuerURL: conf.AuthIssuerURL,
AuthAudienceURL: conf.AuthAudienceURL,
}
s, err := server.NewServer(ctx, srvOpts, rt, logger)
s, err := server.NewServer(ctx, srvOpts, rt, logger, limiter)
if err != nil {
logger.Fatal("error: could not create server", zap.Error(err))
}
Expand Down
3 changes: 2 additions & 1 deletion cli/pkg/local/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/rilldata/rill/runtime/drivers"
"github.com/rilldata/rill/runtime/pkg/graceful"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
runtimeserver "github.com/rilldata/rill/runtime/server"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
Expand Down Expand Up @@ -252,7 +253,7 @@ func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool
AllowedOrigins: []string{"*"},
ServePrometheus: true,
}
runtimeServer, err := runtimeserver.NewServer(ctx, opts, a.Runtime, serverLogger)
runtimeServer, err := runtimeserver.NewServer(ctx, opts, a.Runtime, serverLogger, ratelimit.NewNoop())
if err != nil {
return err
}
Expand Down
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/MicahParks/keyfunc v1.9.0
github.com/NYTimes/gziphandler v1.1.1
github.com/XSAM/otelsql v0.23.0
github.com/alicebob/miniredis v2.5.0+incompatible
github.com/apache/arrow/go/v11 v11.0.0
github.com/apache/calcite-avatica-go/v5 v5.2.0
github.com/aws/aws-sdk-go v1.44.268
Expand All @@ -25,6 +26,7 @@ require (
github.com/go-git/go-git/v5 v5.7.0
github.com/go-logr/zapr v1.2.4
github.com/go-playground/validator/v10 v10.14.0
github.com/go-redis/redis_rate/v10 v10.0.1
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/go-github/v50 v50.2.0
github.com/google/uuid v1.3.0
Expand All @@ -44,6 +46,7 @@ require (
github.com/marcboeker/go-duckdb v1.4.1
github.com/mitchellh/mapstructure v1.5.0
github.com/prometheus/client_golang v1.15.1
github.com/redis/go-redis/v9 v9.0.2
github.com/rs/cors v1.9.0
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
Expand Down Expand Up @@ -93,6 +96,7 @@ require (
github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/ProtonMail/go-crypto v0.0.0-20230518184743-7afd39499903 // indirect
github.com/acomagu/bufpipe v1.0.4 // indirect
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
github.com/andybalholm/brotli v1.0.5 // indirect
github.com/apache/thrift v0.18.1 // indirect
github.com/aws/aws-sdk-go-v2 v1.18.0 // indirect
Expand Down Expand Up @@ -121,6 +125,7 @@ require (
github.com/containerd/containerd v1.7.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/docker/distribution v2.8.1+incompatible // indirect
github.com/docker/docker v23.0.5+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
Expand All @@ -142,6 +147,7 @@ require (
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/gomodule/redigo v1.8.9 // indirect
github.com/google/flatbuffers v23.5.9+incompatible // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-github/v52 v52.0.0 // indirect
Expand Down Expand Up @@ -214,6 +220,7 @@ require (
github.com/xanzy/ssh-agent v0.3.3 // indirect
github.com/xuri/efp v0.0.0-20220603152613-6918739fd470 // indirect
github.com/xuri/nfp v0.0.0-20220409054826-5e722a1d9e22 // indirect
github.com/yuin/gopher-lua v1.1.0 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.16.0 // indirect
Expand Down
Loading

0 comments on commit 4a9b8e5

Please sign in to comment.