Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fastapi: fix wrapping of middlewares #3012

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `opentelemetry-instrumentation-httpx`: instrument_client is a static method again
([#3003](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3003))
- `opentelemetry-instrumentation-fastapi`: instrument unhandled exceptions
([#3012](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3012))

### Breaking changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,14 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
from __future__ import annotations

import logging
import types
from typing import Collection, Literal

import fastapi
from starlette.applications import Starlette
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.routing import Match
from starlette.types import ASGIApp

from opentelemetry.instrumentation._semconv import (
_get_schema_url,
Expand Down Expand Up @@ -280,21 +284,40 @@ def instrument_app(
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)

app.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=http_capture_headers_server_request,
http_capture_headers_server_response=http_capture_headers_server_response,
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
exclude_spans=exclude_spans,
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware
# as the outermost middleware.
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do.

def build_middleware_stack(self: Starlette) -> ASGIApp:
stack = type(self).build_middleware_stack(self)
stack = OpenTelemetryMiddleware(
stack,
excluded_urls=excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=http_capture_headers_server_request,
http_capture_headers_server_response=http_capture_headers_server_response,
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
exclude_spans=exclude_spans,
)
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware
# are handled.
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that
# to impact the user's application just because we wrapped the middlewares in this order.
stack = ServerErrorMiddleware(stack)
Copy link
Contributor

@lzchen lzchen Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, is the purpose of this pr to handle exceptions generated from the OpenTelemetryMiddleware itself (if there is a bug) or unhandled exceptions thrown from the request?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! If there is a bug in OpenTelemetryMiddleware without this double wrapping then the server might not send a 500 to the client and instead abruptly disconnect. In practice I believe every ASGI server I know of does also have it's own handling of unhandled exceptions such that this would not be the case, but I'm not sure we want to rely on that + there would still be a noticeable change (different response content at least, maybe different headers, etc.)

return stack

app._original_build_middleware_stack = app.build_middleware_stack
app.build_middleware_stack = types.MethodType(
build_middleware_stack, app
)

app._is_instrumented_by_opentelemetry = True
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps:
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app)
Expand All @@ -305,11 +328,12 @@ def instrument_app(

@staticmethod
def uninstrument_app(app: fastapi.FastAPI):
app.user_middleware = [
x
for x in app.user_middleware
if x.cls is not OpenTelemetryMiddleware
]
original_build_middleware_stack = getattr(
app, "_original_build_middleware_stack", None
)
if original_build_middleware_stack:
app.build_middleware_stack = original_build_middleware_stack
del app._original_build_middleware_stack
app.middleware_stack = app.build_middleware_stack()
app._is_instrumented_by_opentelemetry = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,19 @@ async def _(param: str):
async def _():
return {"message": "ok"}

@app.get("/error")
async def _():
raise UnhandledException("This is an unhandled exception")

app.mount("/sub", app=sub_app)

return app


class UnhandledException(Exception):
pass


class TestBaseManualFastAPI(TestBaseFastAPI):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -398,6 +406,26 @@ def test_fastapi_excluded_urls(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)

def test_fastapi_unhandled_exception(self):
"""If the application has an unhandled error the instrumentation should capture that a 500 response is returned."""
try:
self._client.get("/error")
except UnhandledException:
pass
else:
self.fail("Expected UnhandledException")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertIn("GET /error", span.name)
self.assertEqual(
span.attributes[SpanAttributes.HTTP_ROUTE], "/error"
)
self.assertEqual(
span.attributes[SpanAttributes.HTTP_STATUS_CODE], 500
)

def test_fastapi_excluded_urls_not_env(self):
"""Ensure that given fastapi routes are excluded when passed explicitly (not in the environment)"""
app = self._create_app_explicit_excluded_urls()
Expand Down
Loading