Skip to content

Commit

Permalink
Fix FastAPI Context Propagation (#420)
Browse files Browse the repository at this point in the history
* Fix context propagation in fastapi.

Co-authored-by: Lalleh Rafeei <lrafeei@users.noreply.github.com>

* Add testing for context propagation errors

* Format

* Restore disabled test

* Clean up context implementation

* Format

* [Mega-Linter] Apply linters fixes

* Bump Tests

* Expand setuptools-scm versions

* Format

* [Mega-Linter] Apply linters fixes

* Bump Tests

Co-authored-by: Lalleh Rafeei <lrafeei@users.noreply.github.com>
Co-authored-by: TimPansino <TimPansino@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 3, 2021
1 parent a3ef06f commit 768019d
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 280 deletions.
63 changes: 51 additions & 12 deletions newrelic/core/context.py
Expand Up @@ -16,41 +16,80 @@
This module implements utilities for context propagation for tracing across threads.
"""

import logging

from newrelic.common.object_wrapper import function_wrapper
from newrelic.core.trace_cache import trace_cache

_logger = logging.getLogger(__name__)


class ContextOf(object):
def __init__(self, trace_cache_id):
def __init__(self, trace=None, request=None, trace_cache_id=None):
self.trace = None
self.trace_cache = trace_cache()
self.trace = self.trace_cache._cache.get(trace_cache_id)
self.thread_id = None
self.restore = None
self.should_restore = False

# Extract trace if possible, else leave as None for safety
if trace is None and request is None and trace_cache_id is None:
_logger.error(
"Runtime instrumentation error. Request context propagation failed. No trace or request provided. Report this issue to New Relic support.",
)
elif trace is not None:
self.trace = trace
elif trace_cache_id is not None:
self.trace = self.trace_cache._cache.get(trace_cache_id, None)
if self.trace is None:
_logger.error(
"Runtime instrumentation error. Request context propagation failed. No trace with id %s. Report this issue to New Relic support.",
trace_cache_id,
)
elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
# Unpack traces from objects patched with them
self.trace = request._nr_trace
else:
_logger.error(
"Runtime instrumentation error. Request context propagation failed. No context attached to request. Report this issue to New Relic support.",
)

def __enter__(self):
if self.trace:
self.thread_id = self.trace_cache.current_thread_id()
self.restore = self.trace_cache._cache.get(self.thread_id)

# Save previous cache contents
self.restore = self.trace_cache._cache.get(self.thread_id, None)
self.should_restore = True

# Set context in trace cache
self.trace_cache._cache[self.thread_id] = self.trace

return self

def __exit__(self, exc, value, tb):
if self.restore:
self.trace_cache._cache[self.thread_id] = self.restore
if self.should_restore:
if self.restore is not None:
# Restore previous contents
self.trace_cache._cache[self.thread_id] = self.restore
else:
# Remove entry from cache
self.trace_cache._cache.pop(self.thread_id)


async def context_wrapper_async(awaitable, trace_cache_id):
with ContextOf(trace_cache_id):
return await awaitable


def context_wrapper(func, trace_cache_id):
def context_wrapper(func, trace=None, request=None, trace_cache_id=None):
@function_wrapper
def _context_wrapper(wrapped, instance, args, kwargs):
with ContextOf(trace_cache_id):
with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id):
return wrapped(*args, **kwargs)

return _context_wrapper(func)


async def context_wrapper_async(awaitable, trace=None, request=None, trace_cache_id=None):
with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id):
return await awaitable


def current_thread_id():
return trace_cache().current_thread_id()
19 changes: 0 additions & 19 deletions newrelic/core/trace_cache.py
Expand Up @@ -275,25 +275,6 @@ def save_trace(self, trace):
task = current_task(self.asyncio)
trace._task = task

def thread_start(self, trace):
current_thread_id = self.current_thread_id()
if current_thread_id not in self._cache:
self._cache[current_thread_id] = trace
else:
_logger.error(
"Runtime instrumentation error. An active "
"trace already exists in the cache on thread_id %s. Report "
"this issue to New Relic support.\n ",
current_thread_id,
)
return None

return current_thread_id

def thread_stop(self, thread_id):
if thread_id:
self._cache.pop(thread_id, None)

def pop_current(self, trace):
"""Restore the trace's parent under the thread ID of the current
executing thread."""
Expand Down
14 changes: 6 additions & 8 deletions newrelic/hooks/adapter_asgiref.py
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from newrelic.api.time_trace import current_trace
from newrelic.common.object_wrapper import wrap_function_wrapper
from newrelic.core.trace_cache import trace_cache
from newrelic.core.context import context_wrapper_async, ContextOf
from newrelic.core.context import ContextOf, context_wrapper_async


def _bind_thread_handler(loop, source_task, *args, **kwargs):
Expand All @@ -23,17 +23,15 @@ def _bind_thread_handler(loop, source_task, *args, **kwargs):

def thread_handler_wrapper(wrapped, instance, args, kwargs):
task = _bind_thread_handler(*args, **kwargs)
with ContextOf(id(task)):
with ContextOf(trace_cache_id=id(task)):
return wrapped(*args, **kwargs)


def main_wrap_wrapper(wrapped, instance, args, kwargs):
awaitable = wrapped(*args, **kwargs)
return context_wrapper_async(awaitable, trace_cache().current_thread_id())
return context_wrapper_async(awaitable, current_trace())


def instrument_asgiref_sync(module):
wrap_function_wrapper(module, 'SyncToAsync.thread_handler',
thread_handler_wrapper)
wrap_function_wrapper(module, 'AsyncToSync.main_wrap',
main_wrap_wrapper)
wrap_function_wrapper(module, "SyncToAsync.thread_handler", thread_handler_wrapper)
wrap_function_wrapper(module, "AsyncToSync.main_wrap", main_wrap_wrapper)
29 changes: 6 additions & 23 deletions newrelic/hooks/framework_fastapi.py
Expand Up @@ -13,25 +13,11 @@
# limitations under the License.

from copy import copy
from newrelic.api.time_trace import current_trace

from newrelic.api.function_trace import FunctionTraceWrapper
from newrelic.common.object_wrapper import wrap_function_wrapper, function_wrapper
from newrelic.api.time_trace import current_trace
from newrelic.common.object_names import callable_name
from newrelic.core.trace_cache import trace_cache


def use_context(trace):

@function_wrapper
def context_wrapper(wrapped, instance, args, kwargs):
cache = trace_cache()
thread_id = cache.thread_start(trace)
try:
return wrapped(*args, **kwargs)
finally:
cache.thread_stop(thread_id)

return context_wrapper
from newrelic.common.object_wrapper import wrap_function_wrapper


def wrap_run_endpoint_function(wrapped, instance, args, kwargs):
Expand All @@ -41,12 +27,9 @@ def wrap_run_endpoint_function(wrapped, instance, args, kwargs):
name = callable_name(dependant.call)
trace.transaction.set_transaction_name(name)

if not kwargs["is_coroutine"]:
dependant = kwargs["dependant"] = copy(dependant)
dependant.call = use_context(trace)(FunctionTraceWrapper(dependant.call))
return wrapped(*args, **kwargs)
else:
return FunctionTraceWrapper(wrapped, name=name)(*args, **kwargs)
dependant = kwargs["dependant"] = copy(dependant)
dependant.call = FunctionTraceWrapper(dependant.call)
return wrapped(*args, **kwargs)

return wrapped(*args, **kwargs)

Expand Down
71 changes: 15 additions & 56 deletions newrelic/hooks/framework_starlette.py
Expand Up @@ -25,8 +25,7 @@
wrap_function_wrapper,
)
from newrelic.core.config import should_ignore_error
from newrelic.core.context import context_wrapper, current_thread_id
from newrelic.core.trace_cache import trace_cache
from newrelic.core.context import ContextOf, context_wrapper


def framework_details():
Expand All @@ -43,30 +42,10 @@ def bind_exc(request, exc, *args, **kwargs):
return exc


class RequestContext(object):
def __init__(self, request):
self.request = request
self.force_propagate = False
self.thread_id = None

def __enter__(self):
trace = getattr(self.request, "_nr_trace", None)
self.force_propagate = trace and current_trace() is None

# Propagate trace context onto the current task
if self.force_propagate:
self.thread_id = trace_cache().thread_start(trace)

def __exit__(self, exc, value, tb):
# Remove any context from the current thread as it was force propagated above
if self.force_propagate:
trace_cache().thread_stop(self.thread_id)


@function_wrapper
def route_naming_wrapper(wrapped, instance, args, kwargs):

with RequestContext(bind_request(*args, **kwargs)):
with ContextOf(request=bind_request(*args, **kwargs)):
transaction = current_transaction()
if transaction:
transaction.set_transaction_name(callable_name(wrapped), priority=2)
Expand Down Expand Up @@ -136,16 +115,14 @@ def wrap_add_middleware(wrapped, instance, args, kwargs):
return wrapped(wrap_middleware(middleware), *args, **kwargs)


def bind_middleware_starlette(
debug=False, routes=None, middleware=None, *args, **kwargs
):
def bind_middleware_starlette(debug=False, routes=None, middleware=None, *args, **kwargs): # pylint: disable=W1113
return middleware


def wrap_starlette(wrapped, instance, args, kwargs):
middlewares = bind_middleware_starlette(*args, **kwargs)
if middlewares:
for middleware in middlewares:
for middleware in middlewares: # pylint: disable=E1133
cls = getattr(middleware, "cls", None)
if cls and not hasattr(cls, "__wrapped__"):
middleware.cls = wrap_middleware(cls)
Expand All @@ -171,11 +148,9 @@ async def wrap_exception_handler_async(coro, exc):

def wrap_exception_handler(wrapped, instance, args, kwargs):
if is_coroutine_function(wrapped):
return wrap_exception_handler_async(
FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs)
)
return wrap_exception_handler_async(FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs))
else:
with RequestContext(bind_request(*args, **kwargs)):
with ContextOf(request=bind_request(*args, **kwargs)):
response = FunctionTraceWrapper(wrapped)(*args, **kwargs)
record_response_error(response, bind_exc(*args, **kwargs))
return response
Expand All @@ -190,9 +165,7 @@ def wrap_server_error_handler(wrapped, instance, args, kwargs):


def wrap_add_exception_handler(wrapped, instance, args, kwargs):
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(
*args, **kwargs
)
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(*args, **kwargs)
handler = FunctionWrapper(handler, wrap_exception_handler)
return wrapped(exc_class_or_status_code, handler, *args, **kwargs)

Expand All @@ -217,7 +190,7 @@ async def wrap_run_in_threadpool(wrapped, instance, args, kwargs):
return await wrapped(*args, **kwargs)

func, args, kwargs = bind_run_in_threadpool(*args, **kwargs)
func = context_wrapper(func, current_thread_id())
func = context_wrapper(func, trace)

return await wrapped(func, *args, **kwargs)

Expand All @@ -241,35 +214,21 @@ def instrument_starlette_requests(module):


def instrument_starlette_middleware_errors(module):
wrap_function_wrapper(
module, "ServerErrorMiddleware.__call__", error_middleware_wrapper
)
wrap_function_wrapper(module, "ServerErrorMiddleware.__call__", error_middleware_wrapper)

wrap_function_wrapper(
module, "ServerErrorMiddleware.__init__", wrap_server_error_handler
)
wrap_function_wrapper(module, "ServerErrorMiddleware.__init__", wrap_server_error_handler)

wrap_function_wrapper(
module, "ServerErrorMiddleware.error_response", wrap_exception_handler
)
wrap_function_wrapper(module, "ServerErrorMiddleware.error_response", wrap_exception_handler)

wrap_function_wrapper(
module, "ServerErrorMiddleware.debug_response", wrap_exception_handler
)
wrap_function_wrapper(module, "ServerErrorMiddleware.debug_response", wrap_exception_handler)


def instrument_starlette_exceptions(module):
wrap_function_wrapper(
module, "ExceptionMiddleware.__call__", error_middleware_wrapper
)
wrap_function_wrapper(module, "ExceptionMiddleware.__call__", error_middleware_wrapper)

wrap_function_wrapper(
module, "ExceptionMiddleware.http_exception", wrap_exception_handler
)
wrap_function_wrapper(module, "ExceptionMiddleware.http_exception", wrap_exception_handler)

wrap_function_wrapper(
module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler
)
wrap_function_wrapper(module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler)


def instrument_starlette_background_task(module):
Expand Down

0 comments on commit 768019d

Please sign in to comment.