Skip to content

Commit

Permalink
Add TraceCache Guarded Iteration (#704)
Browse files Browse the repository at this point in the history
* Add MutableMapping API to TraceCache

* Update trace cache usage to use guarded APIs.

* [Mega-Linter] Apply linters fixes

* Bump tests

* Fix keys iterator

* Comments for trace cache methods

* Reorganize tests

* Fix fixture refs

* Fix testing refs

* [Mega-Linter] Apply linters fixes

* Bump tests

* Upper case constant

Co-authored-by: TimPansino <TimPansino@users.noreply.github.com>
Co-authored-by: Lalleh Rafeei <84813886+lrafeei@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 7, 2022
1 parent a63e33f commit f977ba6
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 39 deletions.
10 changes: 5 additions & 5 deletions newrelic/core/context.py
Expand Up @@ -46,7 +46,7 @@ def log_propagation_failure(s):
elif trace is not None:
self.trace = trace
elif trace_cache_id is not None:
self.trace = self.trace_cache._cache.get(trace_cache_id, None)
self.trace = self.trace_cache.get(trace_cache_id, None)
if self.trace is None:
log_propagation_failure("No trace with id %d." % trace_cache_id)
elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
Expand All @@ -60,22 +60,22 @@ def __enter__(self):
self.thread_id = self.trace_cache.current_thread_id()

# Save previous cache contents
self.restore = self.trace_cache._cache.get(self.thread_id, None)
self.restore = self.trace_cache.get(self.thread_id, None)
self.should_restore = True

# Set context in trace cache
self.trace_cache._cache[self.thread_id] = self.trace
self.trace_cache[self.thread_id] = self.trace

return self

def __exit__(self, exc, value, tb):
if self.should_restore:
if self.restore is not None:
# Restore previous contents
self.trace_cache._cache[self.thread_id] = self.restore
self.trace_cache[self.thread_id] = self.restore
else:
# Remove entry from cache
self.trace_cache._cache.pop(self.thread_id)
self.trace_cache.pop(self.thread_id)


def context_wrapper(func, trace=None, request=None, trace_cache_id=None, strict=True):
Expand Down
103 changes: 82 additions & 21 deletions newrelic/core/trace_cache.py
Expand Up @@ -28,6 +28,11 @@
except ImportError:
import _thread as thread

try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping

from newrelic.core.config import global_settings
from newrelic.core.loop_node import LoopNode

Expand Down Expand Up @@ -92,15 +97,15 @@ class TraceCacheActiveTraceError(RuntimeError):
pass


class TraceCache(object):
class TraceCache(MutableMapping):
asyncio = cached_module("asyncio")
greenlet = cached_module("greenlet")

def __init__(self):
self._cache = weakref.WeakValueDictionary()

def __repr__(self):
return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self._cache.items())))
return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self.items())))

def current_thread_id(self):
"""Returns the thread ID for the caller.
Expand Down Expand Up @@ -135,22 +140,22 @@ def current_thread_id(self):
def task_start(self, task):
trace = self.current_trace()
if trace:
self._cache[id(task)] = trace
self[id(task)] = trace

def task_stop(self, task):
self._cache.pop(id(task), None)
self.pop(id(task), None)

def current_transaction(self):
"""Return the transaction object if one exists for the currently
executing thread.
"""

trace = self._cache.get(self.current_thread_id())
trace = self.get(self.current_thread_id())
return trace and trace.transaction

def current_trace(self):
return self._cache.get(self.current_thread_id())
return self.get(self.current_thread_id())

def active_threads(self):
"""Returns an iterator over all current stack frames for all
Expand All @@ -169,7 +174,7 @@ def active_threads(self):
# First yield up those for real Python threads.

for thread_id, frame in sys._current_frames().items():
trace = self._cache.get(thread_id)
trace = self.get(thread_id)
transaction = trace and trace.transaction
if transaction is not None:
if transaction.background_task:
Expand Down Expand Up @@ -197,7 +202,7 @@ def active_threads(self):
debug = global_settings().debug

if debug.enable_coroutine_profiling:
for thread_id, trace in list(self._cache.items()):
for thread_id, trace in self.items():
transaction = trace.transaction
if transaction and transaction._greenlet is not None:
gr = transaction._greenlet()
Expand All @@ -212,7 +217,7 @@ def prepare_for_root(self):
trace in the cache is from a different task (for asyncio). Returns the
current trace after the cache is updated."""
thread_id = self.current_thread_id()
trace = self._cache.get(thread_id)
trace = self.get(thread_id)
if not trace:
return None

Expand All @@ -221,11 +226,11 @@ def prepare_for_root(self):

task = current_task(self.asyncio)
if task is not None and id(trace._task) != id(task):
self._cache.pop(thread_id, None)
self.pop(thread_id, None)
return None

if trace.root and trace.root.exited:
self._cache.pop(thread_id, None)
self.pop(thread_id, None)
return None

return trace
Expand All @@ -240,8 +245,8 @@ def save_trace(self, trace):

thread_id = trace.thread_id

if thread_id in self._cache:
cache_root = self._cache[thread_id].root
if thread_id in self:
cache_root = self[thread_id].root
if cache_root and cache_root is not trace.root and not cache_root.exited:
# Cached trace exists and has a valid root still
_logger.error(
Expand All @@ -253,7 +258,7 @@ def save_trace(self, trace):

raise TraceCacheActiveTraceError("transaction already active")

self._cache[thread_id] = trace
self[thread_id] = trace

# We judge whether we are actually running in a coroutine by
# seeing if the current thread ID is actually listed in the set
Expand Down Expand Up @@ -284,7 +289,7 @@ def pop_current(self, trace):

thread_id = trace.thread_id
parent = trace.parent
self._cache[thread_id] = parent
self[thread_id] = parent

def complete_root(self, root):
"""Completes a trace specified by the given root
Expand All @@ -301,7 +306,7 @@ def complete_root(self, root):
to_complete = []

for task_id in task_ids:
entry = self._cache.get(task_id)
entry = self.get(task_id)

if entry and entry is not root and entry.root is root:
to_complete.append(entry)
Expand All @@ -316,12 +321,12 @@ def complete_root(self, root):

thread_id = root.thread_id

if thread_id not in self._cache:
if thread_id not in self:
thread_id = self.current_thread_id()
if thread_id not in self._cache:
if thread_id not in self:
raise TraceCacheNoActiveTraceError("no active trace")

current = self._cache.get(thread_id)
current = self.get(thread_id)

if root is not current:
_logger.error(
Expand All @@ -333,7 +338,7 @@ def complete_root(self, root):

raise RuntimeError("not the current trace")

del self._cache[thread_id]
del self[thread_id]
root._greenlet = None

def record_event_loop_wait(self, start_time, end_time):
Expand All @@ -359,7 +364,7 @@ def record_event_loop_wait(self, start_time, end_time):
task = getattr(transaction.root_span, "_task", None)
loop = get_event_loop(task)

for trace in list(self._cache.values()):
for trace in self.values():
if trace in seen:
continue

Expand Down Expand Up @@ -390,6 +395,62 @@ def record_event_loop_wait(self, start_time, end_time):
root.increment_child_count()
root.add_child(node)

# MutableMapping methods

def items(self):
    """
    Yield ``(key, value)`` pairs for the cache entries whose values are still alive.

    Iteration goes through the weak references returned by ``self._cache.valuerefs()``
    rather than ``self._cache.items()``, which avoids RuntimeErrors from the cache
    changing size during iteration.
    """
    for ref in self._cache.valuerefs():
        target = ref()  # Dereference; yields None when the referent is gone.
        if target is None:
            # The referenced object has been garbage collected. Skip the entry.
            continue
        yield ref.key, target  # ref.key is the key the entry was stored under.

def keys(self):
    """
    Yield the key of every weak reference currently held by the cache.

    Iteration goes through ``self._cache.valuerefs()`` rather than ``self._cache.keys()``,
    which avoids RuntimeErrors from the cache changing size during iteration.

    NOTE: The references are not dereferenced here, so a yielded key may belong to a value
    that can be garbage collected at any point. Look values up with
    ``trace_cache.get(key, None)``; ``trace_cache[key]`` can raise a KeyError if the item
    has been garbage collected.
    """
    for ref in self._cache.valuerefs():
        original_key = ref.key  # The key the entry was stored under.
        yield original_key

def values(self):
    """
    Yield every cached value that is still alive.

    Iteration goes through the weak references returned by ``self._cache.valuerefs()``
    rather than ``self._cache.values()``, which avoids RuntimeErrors from the cache
    changing size during iteration.
    """
    for ref in self._cache.valuerefs():
        target = ref()  # Dereference; yields None when the referent is gone.
        if target is None:
            # The referenced object has been garbage collected. Skip it.
            continue
        yield target

def __getitem__(self, key):
    """Return the value stored under ``key`` in the underlying ``self._cache`` mapping."""
    value = self._cache[key]
    return value

def __setitem__(self, key, value):
    """Store ``value`` under ``key`` in the underlying ``self._cache`` mapping."""
    self._cache[key] = value

def __delitem__(self, key):
    """Remove the entry stored under ``key`` from the underlying ``self._cache`` mapping."""
    del self._cache[key]

def __iter__(self):
    """Iterate over the cache by delegating to ``self.keys()``."""
    key_iterator = self.keys()
    return key_iterator

def __len__(self):
    """Return the length reported by the underlying ``self._cache`` mapping."""
    return len(self._cache)

def __bool__(self):
    """Return True when the underlying ``self._cache`` mapping reports a non-zero length."""
    return len(self._cache) != 0


_trace_cache = TraceCache()

Expand Down
12 changes: 6 additions & 6 deletions tests/agent_features/test_async_context_propagation.py
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

import pytest
from testing_support.fixtures import (
function_not_called,
override_generic_settings,
from testing_support.fixtures import function_not_called, override_generic_settings
from testing_support.validators.validate_transaction_metrics import (
validate_transaction_metrics,
)
from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics

from newrelic.api.application import application_instance as application
from newrelic.api.background_task import BackgroundTask, background_task
from newrelic.api.database_trace import database_trace
Expand Down Expand Up @@ -131,7 +131,7 @@ def handle_exception(loop, context):
# The agent should have removed all traces from the cache since
# run_until_complete has terminated (all callbacks scheduled inside the
# task have run)
assert not trace_cache()._cache
assert not trace_cache()

# Assert that no exceptions have occurred
assert not exceptions, exceptions
Expand Down Expand Up @@ -286,7 +286,7 @@ def _test():

# The agent should have removed all traces from the cache since
# run_until_complete has terminated
assert not trace_cache()._cache
assert not trace_cache()

# Assert that no exceptions have occurred
assert not exceptions, exceptions
Expand Down
2 changes: 1 addition & 1 deletion tests/agent_features/test_event_loop_wait_time.py
Expand Up @@ -140,7 +140,7 @@ def _test():
def test_record_event_loop_wait_outside_task():
# Insert a random trace into the trace cache
trace = FunctionTrace(name="testing")
trace_cache()._cache[0] = trace
trace_cache()[0] = trace

@background_task(name="test_record_event_loop_wait_outside_task")
def _test():
Expand Down

0 comments on commit f977ba6

Please sign in to comment.