Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FastAPI Context Propagation #420

Merged
merged 13 commits into from Nov 3, 2021
63 changes: 51 additions & 12 deletions newrelic/core/context.py
Expand Up @@ -16,41 +16,80 @@
This module implements utilities for context propagation for tracing across threads.
"""

import logging

from newrelic.common.object_wrapper import function_wrapper
from newrelic.core.trace_cache import trace_cache

_logger = logging.getLogger(__name__)


class ContextOf(object):
def __init__(self, trace_cache_id):
def __init__(self, trace=None, request=None, trace_cache_id=None):
self.trace = None
self.trace_cache = trace_cache()
self.trace = self.trace_cache._cache.get(trace_cache_id)
self.thread_id = None
self.restore = None
self.should_restore = False

# Extract trace if possible, else leave as None for safety
if trace is None and request is None and trace_cache_id is None:
_logger.error(
"Runtime instrumentation error. Request context propagation failed. No trace or request provided. Report this issue to New Relic support.",
)
elif trace is not None:
self.trace = trace
elif trace_cache_id is not None:
self.trace = self.trace_cache._cache.get(trace_cache_id, None)
if self.trace is None:
_logger.error(
"Runtime instrumentation error. Request context propagation failed. No trace with id %s. Report this issue to New Relic support.",
trace_cache_id,
)
elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
# Unpack traces from objects patched with them
self.trace = request._nr_trace
else:
_logger.error(
"Runtime instrumentation error. Request context propagation failed. No context attached to request. Report this issue to New Relic support.",
)

def __enter__(self):
if self.trace:
self.thread_id = self.trace_cache.current_thread_id()
self.restore = self.trace_cache._cache.get(self.thread_id)

# Save previous cache contents
self.restore = self.trace_cache._cache.get(self.thread_id, None)
self.should_restore = True

# Set context in trace cache
self.trace_cache._cache[self.thread_id] = self.trace

return self

def __exit__(self, exc, value, tb):
if self.restore:
self.trace_cache._cache[self.thread_id] = self.restore
if self.should_restore:
if self.restore is not None:
# Restore previous contents
self.trace_cache._cache[self.thread_id] = self.restore
else:
# Remove entry from cache
self.trace_cache._cache.pop(self.thread_id)


async def context_wrapper_async(awaitable, trace_cache_id):
with ContextOf(trace_cache_id):
return await awaitable


def context_wrapper(func, trace_cache_id):
def context_wrapper(func, trace=None, request=None, trace_cache_id=None):
@function_wrapper
def _context_wrapper(wrapped, instance, args, kwargs):
with ContextOf(trace_cache_id):
with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id):
return wrapped(*args, **kwargs)

return _context_wrapper(func)


async def context_wrapper_async(awaitable, trace=None, request=None, trace_cache_id=None):
with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id):
return await awaitable


def current_thread_id():
return trace_cache().current_thread_id()
19 changes: 0 additions & 19 deletions newrelic/core/trace_cache.py
Expand Up @@ -275,25 +275,6 @@ def save_trace(self, trace):
task = current_task(self.asyncio)
trace._task = task

def thread_start(self, trace):
current_thread_id = self.current_thread_id()
if current_thread_id not in self._cache:
self._cache[current_thread_id] = trace
else:
_logger.error(
"Runtime instrumentation error. An active "
"trace already exists in the cache on thread_id %s. Report "
"this issue to New Relic support.\n ",
current_thread_id,
)
return None

return current_thread_id

def thread_stop(self, thread_id):
if thread_id:
self._cache.pop(thread_id, None)

def pop_current(self, trace):
"""Restore the trace's parent under the thread ID of the current
executing thread."""
Expand Down
14 changes: 6 additions & 8 deletions newrelic/hooks/adapter_asgiref.py
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from newrelic.api.time_trace import current_trace
from newrelic.common.object_wrapper import wrap_function_wrapper
from newrelic.core.trace_cache import trace_cache
from newrelic.core.context import context_wrapper_async, ContextOf
from newrelic.core.context import ContextOf, context_wrapper_async


def _bind_thread_handler(loop, source_task, *args, **kwargs):
Expand All @@ -23,17 +23,15 @@ def _bind_thread_handler(loop, source_task, *args, **kwargs):

def thread_handler_wrapper(wrapped, instance, args, kwargs):
task = _bind_thread_handler(*args, **kwargs)
with ContextOf(id(task)):
with ContextOf(trace_cache_id=id(task)):
return wrapped(*args, **kwargs)


def main_wrap_wrapper(wrapped, instance, args, kwargs):
awaitable = wrapped(*args, **kwargs)
return context_wrapper_async(awaitable, trace_cache().current_thread_id())
return context_wrapper_async(awaitable, current_trace())


def instrument_asgiref_sync(module):
wrap_function_wrapper(module, 'SyncToAsync.thread_handler',
thread_handler_wrapper)
wrap_function_wrapper(module, 'AsyncToSync.main_wrap',
main_wrap_wrapper)
wrap_function_wrapper(module, "SyncToAsync.thread_handler", thread_handler_wrapper)
wrap_function_wrapper(module, "AsyncToSync.main_wrap", main_wrap_wrapper)
29 changes: 6 additions & 23 deletions newrelic/hooks/framework_fastapi.py
Expand Up @@ -13,25 +13,11 @@
# limitations under the License.

from copy import copy
from newrelic.api.time_trace import current_trace

from newrelic.api.function_trace import FunctionTraceWrapper
from newrelic.common.object_wrapper import wrap_function_wrapper, function_wrapper
from newrelic.api.time_trace import current_trace
from newrelic.common.object_names import callable_name
from newrelic.core.trace_cache import trace_cache


def use_context(trace):

@function_wrapper
def context_wrapper(wrapped, instance, args, kwargs):
cache = trace_cache()
thread_id = cache.thread_start(trace)
try:
return wrapped(*args, **kwargs)
finally:
cache.thread_stop(thread_id)

return context_wrapper
from newrelic.common.object_wrapper import wrap_function_wrapper


def wrap_run_endpoint_function(wrapped, instance, args, kwargs):
Expand All @@ -41,12 +27,9 @@ def wrap_run_endpoint_function(wrapped, instance, args, kwargs):
name = callable_name(dependant.call)
trace.transaction.set_transaction_name(name)

if not kwargs["is_coroutine"]:
dependant = kwargs["dependant"] = copy(dependant)
dependant.call = use_context(trace)(FunctionTraceWrapper(dependant.call))
return wrapped(*args, **kwargs)
else:
return FunctionTraceWrapper(wrapped, name=name)(*args, **kwargs)
dependant = kwargs["dependant"] = copy(dependant)
dependant.call = FunctionTraceWrapper(dependant.call)
return wrapped(*args, **kwargs)

return wrapped(*args, **kwargs)

Expand Down
71 changes: 15 additions & 56 deletions newrelic/hooks/framework_starlette.py
Expand Up @@ -25,8 +25,7 @@
wrap_function_wrapper,
)
from newrelic.core.config import should_ignore_error
from newrelic.core.context import context_wrapper, current_thread_id
from newrelic.core.trace_cache import trace_cache
from newrelic.core.context import ContextOf, context_wrapper


def framework_details():
Expand All @@ -43,30 +42,10 @@ def bind_exc(request, exc, *args, **kwargs):
return exc


class RequestContext(object):
def __init__(self, request):
self.request = request
self.force_propagate = False
self.thread_id = None

def __enter__(self):
trace = getattr(self.request, "_nr_trace", None)
self.force_propagate = trace and current_trace() is None

# Propagate trace context onto the current task
if self.force_propagate:
self.thread_id = trace_cache().thread_start(trace)

def __exit__(self, exc, value, tb):
# Remove any context from the current thread as it was force propagated above
if self.force_propagate:
trace_cache().thread_stop(self.thread_id)


@function_wrapper
def route_naming_wrapper(wrapped, instance, args, kwargs):

with RequestContext(bind_request(*args, **kwargs)):
with ContextOf(request=bind_request(*args, **kwargs)):
transaction = current_transaction()
if transaction:
transaction.set_transaction_name(callable_name(wrapped), priority=2)
Expand Down Expand Up @@ -136,16 +115,14 @@ def wrap_add_middleware(wrapped, instance, args, kwargs):
return wrapped(wrap_middleware(middleware), *args, **kwargs)


def bind_middleware_starlette(
debug=False, routes=None, middleware=None, *args, **kwargs
):
def bind_middleware_starlette(debug=False, routes=None, middleware=None, *args, **kwargs): # pylint: disable=W1113
return middleware


def wrap_starlette(wrapped, instance, args, kwargs):
middlewares = bind_middleware_starlette(*args, **kwargs)
if middlewares:
for middleware in middlewares:
for middleware in middlewares: # pylint: disable=E1133
cls = getattr(middleware, "cls", None)
if cls and not hasattr(cls, "__wrapped__"):
middleware.cls = wrap_middleware(cls)
Expand All @@ -171,11 +148,9 @@ async def wrap_exception_handler_async(coro, exc):

def wrap_exception_handler(wrapped, instance, args, kwargs):
if is_coroutine_function(wrapped):
return wrap_exception_handler_async(
FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs)
)
return wrap_exception_handler_async(FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs))
else:
with RequestContext(bind_request(*args, **kwargs)):
with ContextOf(request=bind_request(*args, **kwargs)):
response = FunctionTraceWrapper(wrapped)(*args, **kwargs)
record_response_error(response, bind_exc(*args, **kwargs))
return response
Expand All @@ -190,9 +165,7 @@ def wrap_server_error_handler(wrapped, instance, args, kwargs):


def wrap_add_exception_handler(wrapped, instance, args, kwargs):
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(
*args, **kwargs
)
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(*args, **kwargs)
handler = FunctionWrapper(handler, wrap_exception_handler)
return wrapped(exc_class_or_status_code, handler, *args, **kwargs)

Expand All @@ -217,7 +190,7 @@ async def wrap_run_in_threadpool(wrapped, instance, args, kwargs):
return await wrapped(*args, **kwargs)

func, args, kwargs = bind_run_in_threadpool(*args, **kwargs)
func = context_wrapper(func, current_thread_id())
func = context_wrapper(func, trace)

return await wrapped(func, *args, **kwargs)

Expand All @@ -241,35 +214,21 @@ def instrument_starlette_requests(module):


def instrument_starlette_middleware_errors(module):
wrap_function_wrapper(
module, "ServerErrorMiddleware.__call__", error_middleware_wrapper
)
wrap_function_wrapper(module, "ServerErrorMiddleware.__call__", error_middleware_wrapper)

wrap_function_wrapper(
module, "ServerErrorMiddleware.__init__", wrap_server_error_handler
)
wrap_function_wrapper(module, "ServerErrorMiddleware.__init__", wrap_server_error_handler)

wrap_function_wrapper(
module, "ServerErrorMiddleware.error_response", wrap_exception_handler
)
wrap_function_wrapper(module, "ServerErrorMiddleware.error_response", wrap_exception_handler)

wrap_function_wrapper(
module, "ServerErrorMiddleware.debug_response", wrap_exception_handler
)
wrap_function_wrapper(module, "ServerErrorMiddleware.debug_response", wrap_exception_handler)


def instrument_starlette_exceptions(module):
wrap_function_wrapper(
module, "ExceptionMiddleware.__call__", error_middleware_wrapper
)
wrap_function_wrapper(module, "ExceptionMiddleware.__call__", error_middleware_wrapper)

wrap_function_wrapper(
module, "ExceptionMiddleware.http_exception", wrap_exception_handler
)
wrap_function_wrapper(module, "ExceptionMiddleware.http_exception", wrap_exception_handler)

wrap_function_wrapper(
module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler
)
wrap_function_wrapper(module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler)


def instrument_starlette_background_task(module):
Expand Down