Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TraceCache Guarded Iteration #704

Merged
merged 15 commits into from Dec 7, 2022
10 changes: 5 additions & 5 deletions newrelic/core/context.py
Expand Up @@ -46,7 +46,7 @@ def log_propagation_failure(s):
elif trace is not None:
self.trace = trace
elif trace_cache_id is not None:
self.trace = self.trace_cache._cache.get(trace_cache_id, None)
self.trace = self.trace_cache.get(trace_cache_id, None)
if self.trace is None:
log_propagation_failure("No trace with id %d." % trace_cache_id)
elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
Expand All @@ -60,22 +60,22 @@ def __enter__(self):
self.thread_id = self.trace_cache.current_thread_id()

# Save previous cache contents
self.restore = self.trace_cache._cache.get(self.thread_id, None)
self.restore = self.trace_cache.get(self.thread_id, None)
self.should_restore = True

# Set context in trace cache
self.trace_cache._cache[self.thread_id] = self.trace
self.trace_cache[self.thread_id] = self.trace

return self

def __exit__(self, exc, value, tb):
if self.should_restore:
if self.restore is not None:
# Restore previous contents
self.trace_cache._cache[self.thread_id] = self.restore
self.trace_cache[self.thread_id] = self.restore
else:
# Remove entry from cache
self.trace_cache._cache.pop(self.thread_id)
self.trace_cache.pop(self.thread_id)


def context_wrapper(func, trace=None, request=None, trace_cache_id=None, strict=True):
Expand Down
103 changes: 82 additions & 21 deletions newrelic/core/trace_cache.py
Expand Up @@ -28,6 +28,11 @@
except ImportError:
import _thread as thread

try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping

from newrelic.core.config import global_settings
from newrelic.core.loop_node import LoopNode

Expand Down Expand Up @@ -92,15 +97,15 @@ class TraceCacheActiveTraceError(RuntimeError):
pass


class TraceCache(object):
class TraceCache(MutableMapping):
asyncio = cached_module("asyncio")
greenlet = cached_module("greenlet")

def __init__(self):
self._cache = weakref.WeakValueDictionary()

def __repr__(self):
return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self._cache.items())))
return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self.items())))

def current_thread_id(self):
"""Returns the thread ID for the caller.
Expand Down Expand Up @@ -135,22 +140,22 @@ def current_thread_id(self):
def task_start(self, task):
trace = self.current_trace()
if trace:
self._cache[id(task)] = trace
self[id(task)] = trace

def task_stop(self, task):
self._cache.pop(id(task), None)
self.pop(id(task), None)

def current_transaction(self):
"""Return the transaction object if one exists for the currently
executing thread.

"""

trace = self._cache.get(self.current_thread_id())
trace = self.get(self.current_thread_id())
return trace and trace.transaction

def current_trace(self):
return self._cache.get(self.current_thread_id())
return self.get(self.current_thread_id())

def active_threads(self):
"""Returns an iterator over all current stack frames for all
Expand All @@ -169,7 +174,7 @@ def active_threads(self):
# First yield up those for real Python threads.

for thread_id, frame in sys._current_frames().items():
trace = self._cache.get(thread_id)
trace = self.get(thread_id)
transaction = trace and trace.transaction
if transaction is not None:
if transaction.background_task:
Expand Down Expand Up @@ -197,7 +202,7 @@ def active_threads(self):
debug = global_settings().debug

if debug.enable_coroutine_profiling:
for thread_id, trace in list(self._cache.items()):
for thread_id, trace in self.items():
transaction = trace.transaction
if transaction and transaction._greenlet is not None:
gr = transaction._greenlet()
Expand All @@ -212,7 +217,7 @@ def prepare_for_root(self):
trace in the cache is from a different task (for asyncio). Returns the
current trace after the cache is updated."""
thread_id = self.current_thread_id()
trace = self._cache.get(thread_id)
trace = self.get(thread_id)
if not trace:
return None

Expand All @@ -221,11 +226,11 @@ def prepare_for_root(self):

task = current_task(self.asyncio)
if task is not None and id(trace._task) != id(task):
self._cache.pop(thread_id, None)
self.pop(thread_id, None)
return None

if trace.root and trace.root.exited:
self._cache.pop(thread_id, None)
self.pop(thread_id, None)
return None

return trace
Expand All @@ -240,8 +245,8 @@ def save_trace(self, trace):

thread_id = trace.thread_id

if thread_id in self._cache:
cache_root = self._cache[thread_id].root
if thread_id in self:
cache_root = self[thread_id].root
if cache_root and cache_root is not trace.root and not cache_root.exited:
# Cached trace exists and has a valid root still
_logger.error(
Expand All @@ -253,7 +258,7 @@ def save_trace(self, trace):

raise TraceCacheActiveTraceError("transaction already active")

self._cache[thread_id] = trace
self[thread_id] = trace

# We judge whether we are actually running in a coroutine by
# seeing if the current thread ID is actually listed in the set
Expand Down Expand Up @@ -284,7 +289,7 @@ def pop_current(self, trace):

thread_id = trace.thread_id
parent = trace.parent
self._cache[thread_id] = parent
self[thread_id] = parent

def complete_root(self, root):
"""Completes a trace specified by the given root
Expand All @@ -301,7 +306,7 @@ def complete_root(self, root):
to_complete = []

for task_id in task_ids:
entry = self._cache.get(task_id)
entry = self.get(task_id)

if entry and entry is not root and entry.root is root:
to_complete.append(entry)
Expand All @@ -316,12 +321,12 @@ def complete_root(self, root):

thread_id = root.thread_id

if thread_id not in self._cache:
if thread_id not in self:
thread_id = self.current_thread_id()
if thread_id not in self._cache:
if thread_id not in self:
raise TraceCacheNoActiveTraceError("no active trace")

current = self._cache.get(thread_id)
current = self.get(thread_id)

if root is not current:
_logger.error(
Expand All @@ -333,7 +338,7 @@ def complete_root(self, root):

raise RuntimeError("not the current trace")

del self._cache[thread_id]
del self[thread_id]
root._greenlet = None

def record_event_loop_wait(self, start_time, end_time):
Expand All @@ -359,7 +364,7 @@ def record_event_loop_wait(self, start_time, end_time):
task = getattr(transaction.root_span, "_task", None)
loop = get_event_loop(task)

for trace in list(self._cache.values()):
for trace in self.values():
if trace in seen:
continue

Expand Down Expand Up @@ -390,6 +395,62 @@ def record_event_loop_wait(self, start_time, end_time):
root.increment_child_count()
root.add_child(node)

# MutableMapping methods

def items(self):
    """Yield ``(key, value)`` pairs for every live entry in the cache.

    Iterates over the weak references returned by
    ``self._cache.valuerefs()`` instead of over the dictionary itself, to
    avoid RuntimeErrors from size changes during iteration.
    """
    for wr in self._cache.valuerefs():
        # Calling the weak reference returns None once the referent has
        # been garbage collected; such entries are skipped.
        value = wr()
        if value is not None:
            yield wr.key, value  # wr.key is the original dict key

def keys(self):
    """Yield the key of every weak reference held by the cache.

    Iterates over the weak references returned by
    ``self._cache.valuerefs()`` instead of over the dictionary itself, to
    avoid RuntimeErrors from size changes during iteration.

    NOTE: Returned keys are keys to weak references which may at any point
    be garbage collected. It is only safe to retrieve values from the trace
    cache using trace_cache.get(key, None). Retrieving values using
    trace_cache[key] can cause a KeyError if the item has been garbage
    collected.
    """
    for wr in self._cache.valuerefs():
        # No liveness check here: the key is yielded whether or not the
        # referent is still alive (see NOTE above).
        yield wr.key  # wr.key is the original dict key
TimPansino marked this conversation as resolved.
Show resolved Hide resolved

def values(self):
    """Yield every live value stored in the cache.

    Iterates over the weak references returned by
    ``self._cache.valuerefs()`` instead of over the dictionary itself, to
    avoid RuntimeErrors from size changes during iteration.
    """
    for wr in self._cache.valuerefs():
        # Calling the weak reference returns None once the referent has
        # been garbage collected; such entries are skipped.
        value = wr()
        if value is not None:
            yield value

def __getitem__(self, key):
    """Return the value stored under *key* in the underlying cache.

    Whatever exception the underlying ``self._cache`` lookup raises for a
    missing key is propagated unchanged.
    """
    return self._cache[key]

def __setitem__(self, key, value):
    """Store *value* under *key* in the underlying cache."""
    self._cache[key] = value

def __delitem__(self, key):
    """Remove the entry stored under *key* from the underlying cache.

    Whatever exception the underlying ``self._cache`` deletion raises for a
    missing key is propagated unchanged.
    """
    del self._cache[key]

def __iter__(self):
    """Iterate over the cache's keys by delegating to ``self.keys()``."""
    return self.keys()

def __len__(self):
    """Return the number of entries reported by the underlying cache."""
    return len(self._cache)

def __bool__(self):
    """Return True when the underlying cache reports at least one entry."""
    return len(self._cache) > 0


_trace_cache = TraceCache()

Expand Down
12 changes: 6 additions & 6 deletions tests/agent_features/test_async_context_propagation.py
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

import pytest
from testing_support.fixtures import (
function_not_called,
override_generic_settings,
from testing_support.fixtures import function_not_called, override_generic_settings
from testing_support.validators.validate_transaction_metrics import (
validate_transaction_metrics,
)
from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics

from newrelic.api.application import application_instance as application
from newrelic.api.background_task import BackgroundTask, background_task
from newrelic.api.database_trace import database_trace
Expand Down Expand Up @@ -131,7 +131,7 @@ def handle_exception(loop, context):
# The agent should have removed all traces from the cache since
# run_until_complete has terminated (all callbacks scheduled inside the
# task have run)
assert not trace_cache()._cache
assert not trace_cache()

# Assert that no exceptions have occurred
assert not exceptions, exceptions
Expand Down Expand Up @@ -286,7 +286,7 @@ def _test():

# The agent should have removed all traces from the cache since
# run_until_complete has terminated
assert not trace_cache()._cache
assert not trace_cache()

# Assert that no exceptions have occurred
assert not exceptions, exceptions
Expand Down
2 changes: 1 addition & 1 deletion tests/agent_features/test_event_loop_wait_time.py
Expand Up @@ -140,7 +140,7 @@ def _test():
def test_record_event_loop_wait_outside_task():
# Insert a random trace into the trace cache
trace = FunctionTrace(name="testing")
trace_cache()._cache[0] = trace
trace_cache()[0] = trace

@background_task(name="test_record_event_loop_wait_outside_task")
def _test():
Expand Down