diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index f42af328e..7a7c6087d 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,4 +1,14 @@
-ARG IMAGE=ghcr.io/newrelic-experimental/pyenv-devcontainer:latest
-
 # To target other architectures, change the --platform directive in the Dockerfile.
-FROM --platform=linux/amd64 ${IMAGE}
+ARG IMAGE_TAG=latest
+FROM ghcr.io/newrelic/newrelic-python-agent-ci:${IMAGE_TAG}
+
+# Setup non-root user
+USER root
+ARG UID=1000
+ARG GID=$UID
+ENV HOME /home/vscode
+RUN mkdir -p ${HOME} && \
+    groupadd --gid ${GID} vscode && \
+    useradd --uid ${UID} --gid ${GID} --home ${HOME} vscode && \
+    chown -R ${UID}:${GID} /home/vscode
+USER ${UID}:${GID}
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 92a8cdee4..fbefff476 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -5,7 +5,7 @@
         // To target other architectures, change the --platform directive in the Dockerfile.
         "dockerfile": "Dockerfile",
         "args": {
-            "IMAGE": "ghcr.io/newrelic-experimental/pyenv-devcontainer:latest"
+            "IMAGE_TAG": "latest"
         }
     },
     "remoteUser": "vscode",
diff --git a/.github/containers/Dockerfile b/.github/containers/Dockerfile
index 2fbefb14a..d761b6f4a 100644
--- a/.github/containers/Dockerfile
+++ b/.github/containers/Dockerfile
@@ -23,6 +23,8 @@ RUN export DEBIAN_FRONTEND=noninteractive && \
         build-essential \
         curl \
         expat \
+        fish \
+        fontconfig \
         freetds-common \
         freetds-dev \
         gcc \
@@ -46,13 +48,16 @@ RUN export DEBIAN_FRONTEND=noninteractive && \
         python2-dev \
         python3-dev \
         python3-pip \
+        sudo \
         tzdata \
         unixodbc-dev \
         unzip \
+        vim \
         wget \
         zip \
         zlib1g \
-        zlib1g-dev && \
+        zlib1g-dev \
+        zsh && \
     rm -rf /var/lib/apt/lists/*
 
 # Build librdkafka from source
@@ -93,6 +98,10 @@ RUN echo 'eval "$(pyenv init -)"' >>$HOME/.bashrc && \
 # Install Python
 ARG PYTHON_VERSIONS="3.10 3.9 3.8 3.7 3.11 2.7 pypy2.7-7.3.12 pypy3.8-7.3.11"
 COPY --chown=1000:1000 --chmod=+x ./install-python.sh /tmp/install-python.sh
-COPY ./requirements.txt /requirements.txt
 RUN /tmp/install-python.sh && \
     rm /tmp/install-python.sh
+
+# Install dependencies for main python installation
+COPY ./requirements.txt /tmp/requirements.txt
+RUN pyenv exec pip install --upgrade -r /tmp/requirements.txt && \
+    rm /tmp/requirements.txt
\ No newline at end of file
diff --git a/.github/containers/Makefile b/.github/containers/Makefile
index 35081f738..4c057813d 100644
--- a/.github/containers/Makefile
+++ b/.github/containers/Makefile
@@ -19,16 +19,16 @@ REPO_ROOT:=$(realpath $(MAKEFILE_DIR)../../)
 .PHONY: default
 default: test
 
+# Perform a shortened build for testing
 .PHONY: build
 build:
-	@# Perform a shortened build for testing
	@docker build $(MAKEFILE_DIR) \
		-t ghcr.io/newrelic/newrelic-python-agent-ci:local \
		--build-arg='PYTHON_VERSIONS=3.10 2.7'
 
+# Ensure python versions are usable
 .PHONY: test
 test: build
-	@# Ensure python versions are usable
	@docker run --rm ghcr.io/newrelic/newrelic-python-agent-ci:local /bin/bash -c '\
		python3.10 --version && \
		python2.7 --version && \
diff --git a/.github/containers/install-python.sh b/.github/containers/install-python.sh
index 2031e2d92..f9da0a003 100755
--- a/.github/containers/install-python.sh
+++ b/.github/containers/install-python.sh
@@ -45,9 +45,6 @@ main() {
 
     # Set all installed versions as globally accessible
     pyenv global ${PYENV_VERSIONS[@]}
-
-    # Install dependencies for main python installation
-    pyenv exec pip install --upgrade -r /requirements.txt
 }
 
 main
diff --git a/.github/containers/requirements.txt b/.github/containers/requirements.txt
index 27fa6624b..68bdfe4fe 100644
--- a/.github/containers/requirements.txt
+++ b/.github/containers/requirements.txt
@@ -1,5 +1,9 @@
+bandit
+black
+flake8
+isort
 pip
 setuptools
-wheel
+tox
 virtualenv<20.22.0
-tox
\ No newline at end of file
+wheel
\ No newline at end of file
diff --git a/.github/workflows/build-ci-image.yml b/.github/workflows/build-ci-image.yml
index 5bd0e6f69..8bd904661 100644
--- a/.github/workflows/build-ci-image.yml
+++ b/.github/workflows/build-ci-image.yml
@@ -63,6 +63,6 @@ jobs:
         with:
           push: ${{ github.event_name != 'pull_request' }}
           context: .github/containers
-          platforms: linux/amd64
+          platforms: ${{ (github.ref == 'refs/heads/main') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index b2c221bcf..e3b264a9f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1024,3 +1024,68 @@ jobs:
           name: coverage-${{ github.job }}-${{ strategy.job-index }}
           path: ./**/.coverage.*
           retention-days: 1
+
+  firestore:
+    env:
+      TOTAL_GROUPS: 1
+
+    strategy:
+      fail-fast: false
+      matrix:
+        group-number: [1]
+
+    runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
+    timeout-minutes: 30
+
+    services:
+      firestore:
+        # Image set here MUST be repeated down below in options. See comment below.
+        image: gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators
+        ports:
+          - 8080:8080
+        # Set health checks to wait 5 seconds in lieu of an actual healthcheck
+        options: >-
+          --health-cmd "echo success"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+          --health-start-period 5s
+          gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators /bin/bash -c "gcloud emulators firestore start --host-port=0.0.0.0:8080" ||
+        # This is a very hacky solution. GitHub Actions doesn't provide APIs for setting commands on services, but allows adding arbitrary options.
+        # --entrypoint won't work as it only accepts an executable and not the [] syntax.
+        # Instead, we specify the image again and the command afterwards, like a call to docker create. The result is a few environment variables
+        # and the original command being appended to our hijacked docker create command. We can avoid any issues by adding || to prevent that
+        # from ever being executed as bash commands.
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
+
+      - name: Get Environments
+        id: get-envs
+        run: |
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
+        env:
+          GROUP_NUMBER: ${{ matrix.group-number }}
+
+      - name: Test
+        run: |
+          tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto
+        env:
+          TOX_PARALLEL_NO_SPINNER: 1
+          PY_COLORS: 0
+
+      - name: Upload Coverage Artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: coverage-${{ github.job }}-${{ strategy.job-index }}
+          path: ./**/.coverage.*
+          retention-days: 1
diff --git a/newrelic/api/database_trace.py b/newrelic/api/database_trace.py
index 2bc497688..1069be506 100644
--- a/newrelic/api/database_trace.py
+++ b/newrelic/api/database_trace.py
@@ -16,7 +16,7 @@
 import logging
 
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.database_node import DatabaseNode
 from newrelic.core.stack_trace import current_stack
@@ -244,9 +244,9 @@ def create_node(self):
         )
 
 
-def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None):
+def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None, async_wrapper=None):
     def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -273,9 +273,9 @@ def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_database_trace_wrapper_)
 
 
-def database_trace(sql, dbapi2_module=None):
-    return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module)
+def database_trace(sql, dbapi2_module=None, async_wrapper=None):
+    return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module, async_wrapper=async_wrapper)
 
 
-def wrap_database_trace(module, object_path, sql, dbapi2_module=None):
-    wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module))
+def wrap_database_trace(module, object_path, sql, dbapi2_module=None, async_wrapper=None):
+    wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module, async_wrapper))
diff --git a/newrelic/api/datastore_trace.py b/newrelic/api/datastore_trace.py
index fb40abcab..0401c79ea 100644
--- a/newrelic/api/datastore_trace.py
+++ b/newrelic/api/datastore_trace.py
@@ -15,7 +15,7 @@
 import functools
 
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.datastore_node import DatastoreNode
 
@@ -82,6 +82,9 @@ def __enter__(self):
             self.product = transaction._intern_string(self.product)
             self.target = transaction._intern_string(self.target)
             self.operation = transaction._intern_string(self.operation)
+            self.host = transaction._intern_string(self.host)
+            self.port_path_or_id = transaction._intern_string(self.port_path_or_id)
+            self.database_name = transaction._intern_string(self.database_name)
 
         datastore_tracer_settings = transaction.settings.datastore_tracer
         self.instance_reporting_enabled = datastore_tracer_settings.instance_reporting.enabled
@@ -92,7 +95,14 @@ def __repr__(self):
         return "<%s object at 0x%x %s>" % (
             self.__class__.__name__,
             id(self),
-            dict(product=self.product, target=self.target, operation=self.operation),
+            dict(
+                product=self.product,
+                target=self.target,
+                operation=self.operation,
+                host=self.host,
+                port_path_or_id=self.port_path_or_id,
+                database_name=self.database_name,
+            ),
         )
 
     def finalize_data(self, transaction, exc=None, value=None, tb=None):
@@ -125,7 +135,7 @@ def create_node(self):
         )
 
 
-def DatastoreTraceWrapper(wrapped, product, target, operation):
+def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None):
     """Wraps a method to time datastore queries.
 
     :param wrapped: The function to apply the trace to.
     :type wrapped: object
@@ -140,6 +150,16 @@ def DatastoreTraceWrapper(wrapped, product, target, operation):
        or the name of any API function/method in the client library.
     :type operation: str or callable
 
+    :param host: The name of the server hosting the actual datastore.
+    :type host: str
+    :param port_path_or_id: The value passed in can represent either the port,
+                            path, or id of the datastore being connected to.
+    :type port_path_or_id: str
+    :param database_name: The name of the database where the current query is
+                          being executed.
+    :type database_name: str
+    :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
+    :type async_wrapper: callable or None
     :rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper`
 
     This is typically used to wrap datastore queries such as calls to Redis or
    ElasticSearch.
@@ -155,7 +175,7 @@ def DatastoreTraceWrapper(wrapped, product, target, operation):
     """
 
     def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -187,7 +207,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
         else:
             _operation = operation
 
-        trace = DatastoreTrace(_product, _target, _operation, parent=parent, source=wrapped)
+        if callable(host):
+            if instance is not None:
+                _host = host(instance, *args, **kwargs)
+            else:
+                _host = host(*args, **kwargs)
+        else:
+            _host = host
+
+        if callable(port_path_or_id):
+            if instance is not None:
+                _port_path_or_id = port_path_or_id(instance, *args, **kwargs)
+            else:
+                _port_path_or_id = port_path_or_id(*args, **kwargs)
+        else:
+            _port_path_or_id = port_path_or_id
+
+        if callable(database_name):
+            if instance is not None:
+                _database_name = database_name(instance, *args, **kwargs)
+            else:
+                _database_name = database_name(*args, **kwargs)
+        else:
+            _database_name = database_name
+
+        trace = DatastoreTrace(
+            _product, _target, _operation, _host, _port_path_or_id, _database_name, parent=parent, source=wrapped
+        )
 
         if wrapper:  # pylint: disable=W0125,W0126
             return wrapper(wrapped, trace)(*args, **kwargs)
@@ -198,7 +244,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_datastore_trace_wrapper_)
 
 
-def datastore_trace(product, target, operation):
+def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None):
    """Decorator allows datastore query to be timed.
 
    :param product: The name of the vendor.
@@ -211,6 +257,16 @@ def datastore_trace(product, target, operation):
        or the name of any API function/method in the client library.
     :type operation: str
 
+    :param host: The name of the server hosting the actual datastore.
+    :type host: str
+    :param port_path_or_id: The value passed in can represent either the port,
+                            path, or id of the datastore being connected to.
+    :type port_path_or_id: str
+    :param database_name: The name of the database where the current query is
+                          being executed.
+    :type database_name: str
+    :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
+    :type async_wrapper: callable or None
 
     This is typically used to decorate datastore queries such as calls to
     Redis or ElasticSearch.
@@ -224,10 +280,21 @@ def datastore_trace(product, target, operation):
     >>> @datastore_trace('Redis', None, 'sleep')
     ... def sleep(*args, **kwargs):
     ...        time.sleep(*args, **kwargs)
 
     """
-    return functools.partial(DatastoreTraceWrapper, product=product, target=target, operation=operation)
-
-
-def wrap_datastore_trace(module, object_path, product, target, operation):
+    return functools.partial(
+        DatastoreTraceWrapper,
+        product=product,
+        target=target,
+        operation=operation,
+        host=host,
+        port_path_or_id=port_path_or_id,
+        database_name=database_name,
+        async_wrapper=async_wrapper,
+    )
+
+
+def wrap_datastore_trace(
+    module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None
+):
     """Method applies custom timing to datastore query.
 
     :param module: Module containing the method to be instrumented.
@@ -244,6 +311,16 @@ def wrap_datastore_trace(module, object_path, product, target, operation):
        or the name of any API function/method in the client library.
     :type operation: str
 
+    :param host: The name of the server hosting the actual datastore.
+    :type host: str
+    :param port_path_or_id: The value passed in can represent either the port,
+                            path, or id of the datastore being connected to.
+    :type port_path_or_id: str
+    :param database_name: The name of the database where the current query is
+                          being executed.
+    :type database_name: str
+    :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
+    :type async_wrapper: callable or None
 
     This is typically used to time database query method calls such as Redis
     GET.
@@ -256,4 +333,6 @@ def wrap_datastore_trace(module, object_path, product, target, operation):
     ...        'sleep')
 
     """
-    wrap_object(module, object_path, DatastoreTraceWrapper, (product, target, operation))
+    wrap_object(
+        module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name, async_wrapper)
+    )
diff --git a/newrelic/api/external_trace.py b/newrelic/api/external_trace.py
index c43c560c6..2e147df45 100644
--- a/newrelic/api/external_trace.py
+++ b/newrelic/api/external_trace.py
@@ -16,7 +16,7 @@
 
 from newrelic.api.cat_header_mixin import CatHeaderMixin
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.external_node import ExternalNode
 
@@ -66,9 +66,9 @@ def create_node(self):
         )
 
 
-def ExternalTraceWrapper(wrapped, library, url, method=None):
+def ExternalTraceWrapper(wrapped, library, url, method=None, async_wrapper=None):
     def dynamic_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -103,7 +103,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs):
         return wrapped(*args, **kwargs)
 
     def literal_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -125,9 +125,9 @@ def literal_wrapper(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, literal_wrapper)
 
 
-def external_trace(library, url, method=None):
-    return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method)
+def external_trace(library, url, method=None, async_wrapper=None):
+    return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method, async_wrapper=async_wrapper)
 
 
-def wrap_external_trace(module, object_path, library, url, method=None):
-    wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method))
+def wrap_external_trace(module, object_path, library, url, method=None, async_wrapper=None):
+    wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method, async_wrapper))
diff --git a/newrelic/api/function_trace.py b/newrelic/api/function_trace.py
index 474c1b226..85d7617b6 100644
--- a/newrelic/api/function_trace.py
+++ b/newrelic/api/function_trace.py
@@ -15,7 +15,7 @@
 import functools
 
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_names import callable_name
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.function_node import FunctionNode
 
@@ -89,9 +89,9 @@ def create_node(self):
         )
 
 
-def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None):
+def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None):
     def dynamic_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -147,7 +147,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs):
         return wrapped(*args, **kwargs)
 
     def literal_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -171,13 +171,13 @@ def literal_wrapper(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, literal_wrapper)
 
 
-def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None):
+def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None):
     return functools.partial(
-        FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup
+        FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup, async_wrapper=async_wrapper
     )
 
 
 def wrap_function_trace(
-    module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None
+    module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None
 ):
-    return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup))
+    return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup, async_wrapper))
diff --git a/newrelic/api/graphql_trace.py b/newrelic/api/graphql_trace.py
index 6b0d344a2..e8803fa68 100644
--- a/newrelic/api/graphql_trace.py
+++ b/newrelic/api/graphql_trace.py
@@ -16,7 +16,7 @@
 
 from newrelic.api.time_trace import TimeTrace, current_trace
 from newrelic.api.transaction import current_transaction
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.graphql_node import GraphQLOperationNode, GraphQLResolverNode
 
@@ -109,9 +109,9 @@ def set_transaction_name(self, priority=None):
             transaction.set_transaction_name(name, "GraphQL", priority=priority)
 
 
-def GraphQLOperationTraceWrapper(wrapped):
+def GraphQLOperationTraceWrapper(wrapped, async_wrapper=None):
     def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -130,12 +130,12 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_)
 
 
-def graphql_operation_trace():
-    return functools.partial(GraphQLOperationTraceWrapper)
+def graphql_operation_trace(async_wrapper=None):
+    return functools.partial(GraphQLOperationTraceWrapper, async_wrapper=async_wrapper)
 
 
-def wrap_graphql_operation_trace(module, object_path):
-    wrap_object(module, object_path, GraphQLOperationTraceWrapper)
+def wrap_graphql_operation_trace(module, object_path, async_wrapper=None):
+    wrap_object(module, object_path, GraphQLOperationTraceWrapper, (async_wrapper,))
 
 
 class GraphQLResolverTrace(TimeTrace):
@@ -199,9 +199,9 @@ def create_node(self):
         )
 
 
-def GraphQLResolverTraceWrapper(wrapped):
+def GraphQLResolverTraceWrapper(wrapped, async_wrapper=None):
     def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -220,9 +220,9 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_)
 
 
-def graphql_resolver_trace():
-    return functools.partial(GraphQLResolverTraceWrapper)
+def graphql_resolver_trace(async_wrapper=None):
+    return functools.partial(GraphQLResolverTraceWrapper, async_wrapper=async_wrapper)
 
 
-def wrap_graphql_resolver_trace(module, object_path):
-    wrap_object(module, object_path, GraphQLResolverTraceWrapper)
+def wrap_graphql_resolver_trace(module, object_path, async_wrapper=None):
+    wrap_object(module, object_path, GraphQLResolverTraceWrapper, (async_wrapper,))
diff --git a/newrelic/api/memcache_trace.py b/newrelic/api/memcache_trace.py
index 6657a9ce2..87f12f9fc 100644
--- a/newrelic/api/memcache_trace.py
+++ b/newrelic/api/memcache_trace.py
@@ -15,7 +15,7 @@
 import functools
 
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.memcache_node import MemcacheNode
 
@@ -51,9 +51,9 @@ def create_node(self):
         )
 
 
-def MemcacheTraceWrapper(wrapped, command):
+def MemcacheTraceWrapper(wrapped, command, async_wrapper=None):
     def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -80,9 +80,9 @@ def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_wrapper_memcache_trace_)
 
 
-def memcache_trace(command):
-    return functools.partial(MemcacheTraceWrapper, command=command)
+def memcache_trace(command, async_wrapper=None):
+    return functools.partial(MemcacheTraceWrapper, command=command, async_wrapper=async_wrapper)
 
 
-def wrap_memcache_trace(module, object_path, command):
-    wrap_object(module, object_path, MemcacheTraceWrapper, (command,))
+def wrap_memcache_trace(module, object_path, command, async_wrapper=None):
+    wrap_object(module, object_path, MemcacheTraceWrapper, (command, async_wrapper))
diff --git a/newrelic/api/message_trace.py b/newrelic/api/message_trace.py
index be819d704..f564c41cb 100644
--- a/newrelic/api/message_trace.py
+++ b/newrelic/api/message_trace.py
@@ -16,7 +16,7 @@
 
 from newrelic.api.cat_header_mixin import CatHeaderMixin
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.message_node import MessageNode
 
@@ -91,9 +91,9 @@ def create_node(self):
         )
 
 
-def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True):
+def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None):
     def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -144,7 +144,7 @@ def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_message_trace_wrapper_)
 
 
-def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True):
+def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None):
     return functools.partial(
         MessageTraceWrapper,
         library=library,
@@ -153,10 +153,11 @@ def message_trace(library, operation, destination_type, destination_name, params
         destination_name=destination_name,
         params=params,
         terminal=terminal,
+        async_wrapper=async_wrapper,
     )
 
 
-def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True):
+def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None):
     wrap_object(
-        module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal)
+        module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal, async_wrapper)
     )
diff --git a/newrelic/common/async_wrapper.py b/newrelic/common/async_wrapper.py
index c5f95308d..2d3db2b4b 100644
--- a/newrelic/common/async_wrapper.py
+++ b/newrelic/common/async_wrapper.py
@@ -18,7 +18,9 @@
     is_coroutine_callable,
     is_asyncio_coroutine,
     is_generator_function,
+    is_async_generator_function,
 )
+from newrelic.packages import six
 
 
 def evaluate_wrapper(wrapper_string, wrapped, trace):
@@ -29,7 +31,6 @@ def evaluate_wrapper(wrapper_string, wrapped, trace):
 
 
 def coroutine_wrapper(wrapped, trace):
-
     WRAPPER = textwrap.dedent("""
     @functools.wraps(wrapped)
     async def wrapper(*args, **kwargs):
@@ -61,29 +62,76 @@ def wrapper(*args, **kwargs):
         return wrapped
 
 
-def generator_wrapper(wrapped, trace):
-    @functools.wraps(wrapped)
-    def wrapper(*args, **kwargs):
-        g = wrapped(*args, **kwargs)
-        value = None
-        with trace:
-            while True:
+if six.PY3:
+    def generator_wrapper(wrapped, trace):
+        WRAPPER = textwrap.dedent("""
+        @functools.wraps(wrapped)
+        def wrapper(*args, **kwargs):
+            with trace:
+                result = yield from wrapped(*args, **kwargs)
+                return result
+        """)
+
+        try:
+            return evaluate_wrapper(WRAPPER, wrapped, trace)
+        except:
+            return wrapped
+else:
+    def generator_wrapper(wrapped, trace):
+        @functools.wraps(wrapped)
+        def wrapper(*args, **kwargs):
+            g = wrapped(*args, **kwargs)
+            with trace:
                 try:
-                    yielded = g.send(value)
+                    yielded = g.send(None)
+                    while True:
+                        try:
+                            sent = yield yielded
+                        except GeneratorExit as e:
+                            g.close()
+                            raise
+                        except BaseException as e:
+                            yielded = g.throw(e)
+                        else:
+                            yielded = g.send(sent)
                 except StopIteration:
-                    break
+                    return
+        return wrapper
 
-            try:
-                value = yield yielded
-            except BaseException as e:
-                value = yield g.throw(type(e), e)
-    return wrapper
+
+def async_generator_wrapper(wrapped, trace):
+    WRAPPER = textwrap.dedent("""
+    @functools.wraps(wrapped)
+    async def wrapper(*args, **kwargs):
+        g = wrapped(*args, **kwargs)
+        with trace:
+            try:
+                yielded = await g.asend(None)
+                while True:
+                    try:
+                        sent = yield yielded
+                    except GeneratorExit as e:
+                        await g.aclose()
+                        raise
+                    except BaseException as e:
+                        yielded = await g.athrow(e)
+                    else:
+                        yielded = await g.asend(sent)
+            except StopAsyncIteration:
+                return
+    """)
+
+    try:
+        return evaluate_wrapper(WRAPPER, wrapped, trace)
+    except:
+        return wrapped
 
 
 def async_wrapper(wrapped):
     if is_coroutine_callable(wrapped):
         return coroutine_wrapper
+    elif is_async_generator_function(wrapped):
+        return async_generator_wrapper
     elif is_generator_function(wrapped):
         if is_asyncio_coroutine(wrapped):
             return awaitable_generator_wrapper
diff --git a/newrelic/common/coroutine.py b/newrelic/common/coroutine.py
index cf4c91f85..33a4922f5 100644
--- a/newrelic/common/coroutine.py
+++ b/newrelic/common/coroutine.py
@@ -43,3 +43,11 @@ def _iscoroutinefunction_tornado(fn):
 
 def is_coroutine_callable(wrapped):
     return is_coroutine_function(wrapped) or is_coroutine_function(getattr(wrapped, "__call__", None))
+
+
+if hasattr(inspect, 'isasyncgenfunction'):
+    def is_async_generator_function(wrapped):
+        return inspect.isasyncgenfunction(wrapped)
+else:
+    def is_async_generator_function(wrapped):
+        return False
diff --git a/newrelic/common/package_version_utils.py b/newrelic/common/package_version_utils.py
index f3d334e2a..3152342b4 100644
--- a/newrelic/common/package_version_utils.py
+++ b/newrelic/common/package_version_utils.py
@@ -70,6 +70,23 @@ def int_or_str(value):
 def _get_package_version(name):
     module = sys.modules.get(name, None)
     version = None
+
+    # importlib was introduced into the standard library starting in Python3.8.
+    if "importlib" in sys.modules and hasattr(sys.modules["importlib"], "metadata"):
+        try:
+            # In Python3.10+ packages_distributions can be checked for as well
+            if hasattr(sys.modules["importlib"].metadata, "packages_distributions"):  # pylint: disable=E1101
+                distributions = sys.modules["importlib"].metadata.packages_distributions()  # pylint: disable=E1101
+                distribution_name = distributions.get(name, name)
+            else:
+                distribution_name = name
+
+            version = sys.modules["importlib"].metadata.version(distribution_name)  # pylint: disable=E1101
+            if version not in NULL_VERSIONS:
+                return version
+        except Exception:
+            pass
+
     for attr in VERSION_ATTRS:
         try:
             version = getattr(module, attr, None)
@@ -84,15 +101,6 @@ def _get_package_version(name):
         except Exception:
             pass
 
-    # importlib was introduced into the standard library starting in Python3.8.
- if "importlib" in sys.modules and hasattr(sys.modules["importlib"], "metadata"): - try: - version = sys.modules["importlib"].metadata.version(name) # pylint: disable=E1101 - if version not in NULL_VERSIONS: - return version - except Exception: - pass - if "pkg_resources" in sys.modules: try: version = sys.modules["pkg_resources"].get_distribution(name).version diff --git a/newrelic/config.py b/newrelic/config.py index 8a041ad34..efeeaaec2 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2269,6 +2269,87 @@ def _process_module_builtin_defaults(): "instrument_graphql_validate", ) + _process_module_definition( + "google.cloud.firestore_v1.base_client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_base_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.document", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_document", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_document", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_document", + ) + _process_module_definition( + "google.cloud.firestore_v1.collection", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_collection", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_collection", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_collection", + ) + _process_module_definition( + "google.cloud.firestore_v1.query", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_query", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_query", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_query", + ) + _process_module_definition( + "google.cloud.firestore_v1.aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.bulk_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_bulk_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_transaction", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_transaction", + ) + _process_module_definition( "ariadne.asgi", "newrelic.hooks.framework_ariadne", diff --git a/newrelic/core/rules_engine.py b/newrelic/core/rules_engine.py index fccc5e5e1..62ecce3fe 100644 --- 
--- a/newrelic/core/rules_engine.py
+++ b/newrelic/core/rules_engine.py
@@ -22,6 +22,27 @@
 
 
 class NormalizationRule(_NormalizationRule):
+    def __new__(
+        cls,
+        match_expression="",
+        replacement="",
+        ignore=False,
+        eval_order=0,
+        terminate_chain=False,
+        each_segment=False,
+        replace_all=False,
+    ):
+        return _NormalizationRule.__new__(
+            cls,
+            match_expression=match_expression,
+            replacement=replacement,
+            ignore=ignore,
+            eval_order=eval_order,
+            terminate_chain=terminate_chain,
+            each_segment=each_segment,
+            replace_all=replace_all,
+        )
+
     def __init__(self, *args, **kwargs):
         self.match_expression_re = re.compile(self.match_expression, re.IGNORECASE)
diff --git a/newrelic/core/stats_engine.py b/newrelic/core/stats_engine.py
index 203e3e796..88ec31c6e 100644
--- a/newrelic/core/stats_engine.py
+++ b/newrelic/core/stats_engine.py
@@ -1129,7 +1129,11 @@ def metric_data(self, normalizer=None):
 
         if normalizer is not None:
             for key, value in six.iteritems(self.__stats_table):
-                key = (normalizer(key[0])[0], key[1])
+                normalized_name, ignored = normalizer(key[0])
+                if ignored:
+                    continue
+
+                key = (normalized_name, key[1])
                 stats = normalized_stats.get(key)
                 if stats is None:
                     normalized_stats[key] = copy.copy(value)
diff --git a/newrelic/hooks/component_graphqlserver.py b/newrelic/hooks/component_graphqlserver.py
index 29004c11f..ebc62a34d 100644
--- a/newrelic/hooks/component_graphqlserver.py
+++ b/newrelic/hooks/component_graphqlserver.py
@@ -1,19 +1,18 @@
-from newrelic.api.asgi_application import wrap_asgi_application
 from newrelic.api.error_trace import ErrorTrace
 from newrelic.api.graphql_trace import GraphQLOperationTrace
 from newrelic.api.transaction import current_transaction
-from newrelic.api.transaction_name import TransactionNameWrapper
 from newrelic.common.object_names import callable_name
 from newrelic.common.object_wrapper import wrap_function_wrapper
+from newrelic.common.package_version_utils import get_package_version
 from newrelic.core.graphql_utils import graphql_statement
 from newrelic.hooks.framework_graphql import (
-    framework_version as graphql_framework_version,
+    GRAPHQL_VERSION,
+    ignore_graphql_duplicate_exception,
 )
-from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception
 
-
-def framework_details():
-    import graphql_server
-    return ("GraphQLServer", getattr(graphql_server, "__version__", None))
+GRAPHQL_SERVER_VERSION = get_package_version("graphql-server")
+graphql_server_major_version = int(GRAPHQL_SERVER_VERSION.split(".")[0])
+
 
 def bind_query(schema, params, *args, **kwargs):
     return getattr(params, "query", None)
@@ -30,9 +29,8 @@ def wrap_get_response(wrapped, instance, args, kwargs):
     except TypeError:
         return wrapped(*args, **kwargs)
 
-    framework = framework_details()
-    transaction.add_framework_info(name=framework[0], version=framework[1])
-    transaction.add_framework_info(name="GraphQL", version=graphql_framework_version())
+    transaction.add_framework_info(name="GraphQLServer", version=GRAPHQL_SERVER_VERSION)
+    transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION)
 
     if hasattr(query, "body"):
         query = query.body
@@ -45,5 +43,8 @@ def wrap_get_response(wrapped, instance, args, kwargs):
         with ErrorTrace(ignore=ignore_graphql_duplicate_exception):
             return wrapped(*args, **kwargs)
 
+
 def instrument_graphqlserver(module):
-    wrap_function_wrapper(module, "get_response", wrap_get_response)
+    if graphql_server_major_version <= 2:
+        return
+    wrap_function_wrapper(module, "get_response", wrap_get_response)
diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py
new file mode 100644
index 000000000..6d3196a7c
--- /dev/null
+++ b/newrelic/hooks/datastore_firestore.py
@@ -0,0 +1,473 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from newrelic.api.datastore_trace import wrap_datastore_trace
+from newrelic.api.function_trace import wrap_function_trace
+from newrelic.common.async_wrapper import generator_wrapper, async_generator_wrapper
+
+
+def _conn_str_to_host(getter):
+    """Safely transform a getter that can retrieve a connection string into the resulting host."""
+
+    def closure(obj, *args, **kwargs):
+        try:
+            return getter(obj, *args, **kwargs).split(":")[0]
+        except Exception:
+            return None
+
+    return closure
+
+
+def _conn_str_to_port(getter):
+    """Safely transform a getter that can retrieve a connection string into the resulting port."""
+
+    def closure(obj, *args, **kwargs):
+        try:
+            return getter(obj, *args, **kwargs).split(":")[1]
+        except Exception:
+            return None
+
+    return closure
+
+
+# Default Target ID and Instance Info
+_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None)
+_get_client_database_string = lambda obj, *args, **kwargs: getattr(
+    getattr(obj, "_client", None), "_database_string", None
+)
+_get_client_target = lambda obj, *args, **kwargs: obj._client._target
+_get_client_target_host = _conn_str_to_host(_get_client_target)
+_get_client_target_port = _conn_str_to_port(_get_client_target)
+
+# Client Instance Info
+_get_database_string = lambda obj, *args, **kwargs: getattr(obj, "_database_string", None)
+_get_target = lambda obj, *args, **kwargs: obj._target
+_get_target_host = _conn_str_to_host(_get_target)
+_get_target_port = _conn_str_to_port(_get_target)
+
+# Query Target ID
+_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None)
+
+# AggregationQuery Target ID
+_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None)
+
+
+def instrument_google_cloud_firestore_v1_base_client(module):
+    rollup = ("Datastore/all", "Datastore/Firestore/all")
+    wrap_function_trace(
+        module, "BaseClient.__init__", name="%s:BaseClient.__init__" % module.__name__, terminal=True, rollup=rollup
+    )
+
+
+def instrument_google_cloud_firestore_v1_client(module):
+    if hasattr(module, "Client"):
+        class_ = module.Client
+        for method in ("collections", "get_all"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "Client.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=None,
+                    host=_get_target_host,
+                    port_path_or_id=_get_target_port,
+                    database_name=_get_database_string,
+                    async_wrapper=generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_client(module):
+    if hasattr(module, "AsyncClient"):
+        class_ = module.AsyncClient
+        for method in ("collections", "get_all"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncClient.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=None,
+                    host=_get_target_host,
+                    port_path_or_id=_get_target_port,
+                    database_name=_get_database_string,
+                    async_wrapper=async_generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_collection(module):
+    if hasattr(module, "CollectionReference"):
+        class_ = module.CollectionReference
+        for method in ("add", "get"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "CollectionReference.%s" % method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("stream", "list_documents"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "CollectionReference.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_collection(module):
+    if hasattr(module, "AsyncCollectionReference"):
+        class_ = module.AsyncCollectionReference
+        for method in ("add", "get"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncCollectionReference.%s" % method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    operation=method,
+                )
+
+        for method in ("stream", "list_documents"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncCollectionReference.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=async_generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_document(module):
+    if hasattr(module, "DocumentReference"):
+        class_ = module.DocumentReference
+        for method in ("create", "delete", "get", "set", "update"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "DocumentReference.%s" % method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("collections",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "DocumentReference.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_document(module):
+    if hasattr(module, "AsyncDocumentReference"):
+        class_ = module.AsyncDocumentReference
+        for method in ("create", "delete", "get", "set", "update"):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncDocumentReference.%s" % method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("collections",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncDocumentReference.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_object_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=async_generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_query(module):
+    if hasattr(module, "Query"):
+        class_ = module.Query
+        for method in ("get",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "Query.%s" % method,
+                    product="Firestore",
+                    target=_get_parent_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("stream",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "Query.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_parent_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=generator_wrapper,
+                )
+
+    if hasattr(module, "CollectionGroup"):
+        class_ = module.CollectionGroup
+        for method in ("get_partitions",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "CollectionGroup.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_parent_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_query(module):
+    if hasattr(module, "AsyncQuery"):
+        class_ = module.AsyncQuery
+        for method in ("get",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncQuery.%s" % method,
+                    product="Firestore",
+                    target=_get_parent_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("stream",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncQuery.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_parent_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=async_generator_wrapper,
+                )
+
+    if hasattr(module, "AsyncCollectionGroup"):
+        class_ = module.AsyncCollectionGroup
+        for method in ("get_partitions",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncCollectionGroup.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_parent_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=async_generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_aggregation(module):
+    if hasattr(module, "AggregationQuery"):
+        class_ = module.AggregationQuery
+        for method in ("get",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AggregationQuery.%s" % method,
+                    product="Firestore",
+                    target=_get_collection_ref_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("stream",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AggregationQuery.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_collection_ref_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_aggregation(module):
+    if hasattr(module, "AsyncAggregationQuery"):
+        class_ = module.AsyncAggregationQuery
+        for method in ("get",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncAggregationQuery.%s" % method,
+                    product="Firestore",
+                    target=_get_collection_ref_id,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+        for method in ("stream",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncAggregationQuery.%s" % method,
+                    operation=method,
+                    product="Firestore",
+                    target=_get_collection_ref_id,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                    async_wrapper=async_generator_wrapper,
+                )
+
+
+def instrument_google_cloud_firestore_v1_batch(module):
+    if hasattr(module, "WriteBatch"):
+        class_ = module.WriteBatch
+        for method in ("commit",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "WriteBatch.%s" % method,
+                    product="Firestore",
+                    target=None,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_batch(module):
+    if hasattr(module, "AsyncWriteBatch"):
+        class_ = module.AsyncWriteBatch
+        for method in ("commit",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "AsyncWriteBatch.%s" % method,
+                    product="Firestore",
+                    target=None,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+
+def instrument_google_cloud_firestore_v1_bulk_batch(module):
+    if hasattr(module, "BulkWriteBatch"):
+        class_ = module.BulkWriteBatch
+        for method in ("commit",):
+            if hasattr(class_, method):
+                wrap_datastore_trace(
+                    module,
+                    "BulkWriteBatch.%s" % method,
+                    product="Firestore",
+                    target=None,
+                    operation=method,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+
+def instrument_google_cloud_firestore_v1_transaction(module):
+    if hasattr(module, "Transaction"):
+        class_ = module.Transaction
+        for method in ("_commit", "_rollback"):
+            if hasattr(class_, method):
+                operation = method[1:]  # Trim leading underscore
+                wrap_datastore_trace(
+                    module,
+                    "Transaction.%s" % method,
+                    product="Firestore",
+                    target=None,
+                    operation=operation,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
+
+
+def instrument_google_cloud_firestore_v1_async_transaction(module):
+    if hasattr(module, "AsyncTransaction"):
+        class_ = module.AsyncTransaction
+        for method in ("_commit", "_rollback"):
+            if hasattr(class_, method):
+                operation = method[1:]  # Trim leading underscore
+                wrap_datastore_trace(
+                    module,
+                    "AsyncTransaction.%s" % method,
+                    product="Firestore",
+                    target=None,
+                    operation=operation,
+                    host=_get_client_target_host,
+                    port_path_or_id=_get_client_target_port,
+                    database_name=_get_client_database_string,
+                )
diff --git a/newrelic/hooks/datastore_redis.py b/newrelic/hooks/datastore_redis.py
index 6854d84f3..0f1c522b7 100644
--- a/newrelic/hooks/datastore_redis.py
+++ b/newrelic/hooks/datastore_redis.py
@@ -161,6 +161,7 @@
     "cluster_reset",
     "cluster_save_config",
     "cluster_set_config_epoch",
+    "client_setinfo",
     "cluster_setslot",
     "cluster_slaves",
     "cluster_slots",
@@ -219,6 +220,7 @@
     "function_load",
     "function_restore",
     "function_stats",
+    "gears_refresh_cluster",
"geoadd", "geodist", "geohash", @@ -320,6 +322,8 @@ "pubsub_channels", "pubsub_numpat", "pubsub_numsub", + "pubsub_shardchannels", + "pubsub_shardnumsub", "punsubscribe", "quantile", "query", @@ -374,6 +378,7 @@ "smismember", "smove", "spellcheck", + "spublish", "srem", "sscan_iter", "sscan", @@ -393,6 +398,11 @@ "syndump", "synupdate", "tagvals", + "tfcall", + "tfcall_async", + "tfunction_delete", + "tfunction_list", + "tfunction_load", "time", "toggle", "touch", diff --git a/newrelic/hooks/framework_ariadne.py b/newrelic/hooks/framework_ariadne.py index 7d55a89b8..4927abe0b 100644 --- a/newrelic/hooks/framework_ariadne.py +++ b/newrelic/hooks/framework_ariadne.py @@ -21,25 +21,12 @@ from newrelic.api.wsgi_application import wrap_wsgi_application from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement -from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, -) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception +from newrelic.hooks.framework_graphql import GRAPHQL_VERSION, ignore_graphql_duplicate_exception - -def framework_details(): - try: - import ariadne - version = ariadne.__version__ - except Exception: - try: - import pkg_resources - version = pkg_resources.get_distribution("ariadne").version - except Exception: - version = None - - return ("Ariadne", version) +ARIADNE_VERSION = get_package_version("ariadne") +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) def bind_graphql(schema, data, *args, **kwargs): @@ -57,9 +44,8 @@ def wrap_graphql_sync(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) # No version info available on ariadne - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Ariadne", version=ARIADNE_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) query = data["query"] if hasattr(query, "body"): @@ -91,9 +77,8 @@ async def wrap_graphql(wrapped, instance, args, kwargs): result = await result return result - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) # No version info available on ariadne - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Ariadne", version=ARIADNE_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) query = data["query"] if hasattr(query, "body"): @@ -112,6 +97,9 @@ async def wrap_graphql(wrapped, instance, args, kwargs): def instrument_ariadne_execute(module): + # v0.9.0 is the version where ariadne started using graphql-core v3 + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "graphql"): wrap_function_wrapper(module, "graphql", wrap_graphql) @@ -120,10 +108,14 @@ def instrument_ariadne_execute(module): def instrument_ariadne_asgi(module): + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "GraphQL"): - wrap_asgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_asgi_application(module, "GraphQL.__call__", framework=("Ariadne", ARIADNE_VERSION)) def instrument_ariadne_wsgi(module): + if ariadne_version_tuple < (0, 
+    if ariadne_version_tuple < (0, 9):
+        return
     if hasattr(module, "GraphQL"):
-        wrap_wsgi_application(module, "GraphQL.__call__", framework=framework_details())
+        wrap_wsgi_application(module, "GraphQL.__call__", framework=("Ariadne", ARIADNE_VERSION))
diff --git a/newrelic/hooks/framework_graphql.py b/newrelic/hooks/framework_graphql.py
index 3e1d4333c..df86e6984 100644
--- a/newrelic/hooks/framework_graphql.py
+++ b/newrelic/hooks/framework_graphql.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 import logging
 import sys
 import time
 from collections import deque
+from inspect import isawaitable
 
 from newrelic.api.error_trace import ErrorTrace
 from newrelic.api.function_trace import FunctionTrace
@@ -25,42 +25,14 @@
 from newrelic.api.transaction import current_transaction, ignore_transaction
 from newrelic.common.object_names import callable_name, parse_exc_info
 from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper
+from newrelic.common.package_version_utils import get_package_version
 from newrelic.core.graphql_utils import graphql_statement
-from newrelic.packages import six
-
-try:
-    from inspect import isawaitable
-except ImportError:
-
-    def isawaitable(f):
-        return False
-
-
-try:
-    # from promise import is_thenable as is_promise
-    from promise import Promise
-
-    def is_promise(obj):
-        return isinstance(obj, Promise)
-
-    def as_promise(f):
-        return Promise.resolve(None).then(f)
-
-except ImportError:
-    # If promises is not installed, prevent crashes by bypassing logic
-    def is_promise(obj):
-        return False
-
-    def as_promise(f):
-        return f
-
-if six.PY3:
-    from newrelic.hooks.framework_graphql_py3 import (
-        nr_coro_execute_name_wrapper,
-        nr_coro_resolver_error_wrapper,
-        nr_coro_resolver_wrapper,
-        nr_coro_graphql_impl_wrapper,
-    )
+from newrelic.hooks.framework_graphql_py3 import (
+    nr_coro_execute_name_wrapper,
+    nr_coro_graphql_impl_wrapper,
+    nr_coro_resolver_error_wrapper,
+    nr_coro_resolver_wrapper,
+)
 
 _logger = logging.getLogger(__name__)
 
@@ -70,23 +42,8 @@
 VERSION = None
 
 
-def framework_version():
-    """Framework version string."""
-    global VERSION
-    if VERSION is None:
-        from graphql import __version__ as version
-
-        VERSION = version
-
-    return VERSION
-
-
-def graphql_version():
-    """Minor version tuple."""
-    version = framework_version()
-
-    # Take first two values in version to avoid ValueErrors with pre-releases (ex: 3.2.0a0)
-    return tuple(int(v) for v in version.split(".")[:2])
+GRAPHQL_VERSION = get_package_version("graphql-core")
+major_version = int(GRAPHQL_VERSION.split(".")[0])
 
 
 def ignore_graphql_duplicate_exception(exc, val, tb):
@@ -115,20 +72,6 @@ def ignore_graphql_duplicate_exception(exc, val, tb):
     return None  # Follow original exception matching rules
 
 
-def catch_promise_error(e):
-    if hasattr(e, "__traceback__"):
-        notice_error(error=(e.__class__, e, e.__traceback__), ignore=ignore_graphql_duplicate_exception)
-    else:
-        # Python 2 does not retain a reference to the traceback and is irretrievable from a promise.
-        # As a workaround, raise the error and report it despite having an incorrect traceback.
- try: - raise e - except Exception: - notice_error(ignore=ignore_graphql_duplicate_exception) - - return None - - def wrap_executor_context_init(wrapped, instance, args, kwargs): result = wrapped(*args, **kwargs) @@ -150,10 +93,6 @@ def bind_operation_v3(operation, root_value): return operation -def bind_operation_v2(exe_context, operation, root_value): - return operation - - def wrap_execute_operation(wrapped, instance, args, kwargs): transaction = current_transaction() trace = current_trace() @@ -170,15 +109,9 @@ def wrap_execute_operation(wrapped, instance, args, kwargs): try: operation = bind_operation_v3(*args, **kwargs) except TypeError: - try: - operation = bind_operation_v2(*args, **kwargs) - except TypeError: - return wrapped(*args, **kwargs) + return wrapped(*args, **kwargs) - if graphql_version() < (3, 0): - execution_context = args[0] - else: - execution_context = instance + execution_context = instance trace.operation_name = get_node_value(operation, "name") or "" @@ -203,14 +136,11 @@ def set_name(value=None): # Operation trace sets transaction name trace.set_transaction_name(priority=14) return value - - if is_promise(result) and result.is_pending and graphql_version() < (3, 0): - return result.then(set_name) - elif isawaitable(result) and not is_promise(result): + + if isawaitable(result): return nr_coro_execute_name_wrapper(wrapped, result, set_name) else: - set_name() - return result + return set_name(result) def get_node_value(field, attr, subattr="value"): @@ -221,39 +151,25 @@ def get_node_value(field, attr, subattr="value"): def is_fragment_spread_node(field): - # Resolve version specific imports - try: - from graphql.language.ast import FragmentSpread - except ImportError: - from graphql import FragmentSpreadNode as FragmentSpread + from graphql.language.ast import FragmentSpreadNode - return isinstance(field, FragmentSpread) + return isinstance(field, FragmentSpreadNode) def is_fragment(field): - # Resolve version specific imports - try: - from graphql.language.ast import FragmentSpread, InlineFragment - except ImportError: - from graphql import FragmentSpreadNode as FragmentSpread - from graphql import InlineFragmentNode as InlineFragment - - _fragment_types = (InlineFragment, FragmentSpread) + from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode + _fragment_types = (InlineFragmentNode, FragmentSpreadNode) return isinstance(field, _fragment_types) def is_named_fragment(field): - # Resolve version specific imports - try: - from graphql.language.ast import NamedType - except ImportError: - from graphql import NamedTypeNode as NamedType + from graphql.language.ast import NamedTypeNode return ( is_fragment(field) and getattr(field, "type_condition", None) is not None - and isinstance(field.type_condition, NamedType) + and isinstance(field.type_condition, NamedTypeNode) ) @@ -333,8 +249,7 @@ def wrap_middleware(wrapped, instance, args, kwargs): transaction.set_transaction_name(name, "GraphQL", priority=12) with FunctionTrace(name, source=wrapped): with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - result = wrapped(*args, **kwargs) - return result + return wrapped(*args, **kwargs) def bind_get_field_resolver(field_resolver): @@ -391,19 +306,8 @@ def wrap_resolver(wrapped, instance, args, kwargs): with ErrorTrace(ignore=ignore_graphql_duplicate_exception): sync_start_time = time.time() result = wrapped(*args, **kwargs) - - if is_promise(result) and result.is_pending and graphql_version() < (3, 0): - @functools.wraps(wrapped) - def 
nr_promise_resolver_error_wrapper(v): - with trace: - with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - try: - return result.get() - except Exception: - transaction.set_transaction_name(name, "GraphQL", priority=15) - raise - return as_promise(nr_promise_resolver_error_wrapper) - elif isawaitable(result) and not is_promise(result): + + if isawaitable(result): # Grab any async resolvers and wrap with traces return nr_coro_resolver_error_wrapper( wrapped, name, trace, ignore_graphql_duplicate_exception, result, transaction @@ -411,9 +315,6 @@ def nr_promise_resolver_error_wrapper(v): else: with trace: trace.start_time = sync_start_time - if is_promise(result) and result.is_rejected: - result.catch(catch_promise_error).get() - transaction.set_transaction_name(name, "GraphQL", priority=15) return result @@ -456,19 +357,12 @@ def bind_resolve_field_v3(parent_type, source, field_nodes, path): return parent_type, field_nodes, path -def bind_resolve_field_v2(exe_context, parent_type, source, field_asts, parent_info, field_path): - return parent_type, field_asts, field_path - - def wrap_resolve_field(wrapped, instance, args, kwargs): transaction = current_transaction() if transaction is None: return wrapped(*args, **kwargs) - if graphql_version() < (3, 0): - bind_resolve_field = bind_resolve_field_v2 - else: - bind_resolve_field = bind_resolve_field_v3 + bind_resolve_field = bind_resolve_field_v3 try: parent_type, field_asts, field_path = bind_resolve_field(*args, **kwargs) @@ -497,14 +391,7 @@ def wrap_resolve_field(wrapped, instance, args, kwargs): notice_error(ignore=ignore_graphql_duplicate_exception) raise - if is_promise(result) and result.is_pending and graphql_version() < (3, 0): - @functools.wraps(wrapped) - def nr_promise_resolver_wrapper(v): - with trace: - with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - return result.get() - return as_promise(nr_promise_resolver_wrapper) - elif isawaitable(result) and not is_promise(result): + if isawaitable(result): # Asynchronous resolvers (returned coroutines from non-coroutine functions) # Return a coroutine that handles wrapping in a resolver trace return nr_coro_resolver_wrapper(wrapped, trace, ignore_graphql_duplicate_exception, result) @@ -539,11 +426,8 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): if not transaction: return wrapped(*args, **kwargs) - transaction.add_framework_info(name="GraphQL", version=framework_version()) - if graphql_version() < (3, 0): - bind_query = bind_execute_graphql_query - else: - bind_query = bind_graphql_impl_query + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) + bind_query = bind_graphql_impl_query try: schema, query = bind_query(*args, **kwargs) @@ -564,7 +448,7 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): framework = schema._nr_framework trace.product = framework[0] transaction.add_framework_info(name=framework[0], version=framework[1]) - + # Trace must be manually started and stopped to ensure it exists prior to and during the entire duration of the query. # Otherwise subsequent instrumentation will not be able to find an operation trace and will have issues. trace.__enter__() @@ -576,19 +460,7 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): trace.__exit__(*sys.exc_info()) raise else: - if is_promise(result) and result.is_pending: - # Execution promise, append callbacks to exit trace. 
- def on_resolve(v): - trace.__exit__(None, None, None) - return v - - def on_reject(e): - catch_promise_error(e) - trace.__exit__(e.__class__, e, e.__traceback__) - return e - - return result.then(on_resolve, on_reject) - elif isawaitable(result) and not is_promise(result): + if isawaitable(result): # Asynchronous implementations # Return a coroutine that handles closing the operation trace return nr_coro_graphql_impl_wrapper(wrapped, trace, ignore_graphql_duplicate_exception, result) @@ -620,11 +492,15 @@ def instrument_graphql_execute(module): def instrument_graphql_execution_utils(module): + if major_version == 2: + return if hasattr(module, "ExecutionContext"): wrap_function_wrapper(module, "ExecutionContext.__init__", wrap_executor_context_init) def instrument_graphql_execution_middleware(module): + if major_version == 2: + return if hasattr(module, "get_middleware_resolvers"): wrap_function_wrapper(module, "get_middleware_resolvers", wrap_get_middleware_resolvers) if hasattr(module, "MiddlewareManager"): @@ -632,20 +508,26 @@ def instrument_graphql_execution_middleware(module): def instrument_graphql_error_located_error(module): + if major_version == 2: + return if hasattr(module, "located_error"): wrap_function_wrapper(module, "located_error", wrap_error_handler) def instrument_graphql_validate(module): + if major_version == 2: + return wrap_function_wrapper(module, "validate", wrap_validate) def instrument_graphql(module): + if major_version == 2: + return if hasattr(module, "graphql_impl"): wrap_function_wrapper(module, "graphql_impl", wrap_graphql_impl) - if hasattr(module, "execute_graphql"): - wrap_function_wrapper(module, "execute_graphql", wrap_graphql_impl) def instrument_graphql_parser(module): + if major_version == 2: + return wrap_function_wrapper(module, "parse", wrap_parse) diff --git a/newrelic/hooks/framework_strawberry.py b/newrelic/hooks/framework_strawberry.py index cfbe450d6..e6d06bb04 100644 --- a/newrelic/hooks/framework_strawberry.py +++ b/newrelic/hooks/framework_strawberry.py @@ -16,29 +16,14 @@ from newrelic.api.error_trace import ErrorTrace from newrelic.api.graphql_trace import GraphQLOperationTrace from newrelic.api.transaction import current_transaction -from newrelic.api.transaction_name import TransactionNameWrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement -from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, -) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception +from newrelic.hooks.framework_graphql import GRAPHQL_VERSION, ignore_graphql_duplicate_exception - -def framework_details(): - import strawberry - - try: - version = strawberry.__version__ - except Exception: - try: - import pkg_resources - version = pkg_resources.get_distribution("strawberry-graphql").version - except Exception: - version = None - - return ("Strawberry", version) +STRAWBERRY_GRAPHQL_VERSION = get_package_version("strawberry-graphql") +strawberry_version_tuple = tuple(map(int, STRAWBERRY_GRAPHQL_VERSION.split("."))) def bind_execute(query, *args, **kwargs): @@ -56,9 +41,8 @@ def wrap_execute_sync(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - 
transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Strawberry", version=STRAWBERRY_GRAPHQL_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -83,9 +67,8 @@ async def wrap_execute(wrapped, instance, args, kwargs): except TypeError: return await wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Strawberry", version=STRAWBERRY_GRAPHQL_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -107,7 +90,7 @@ def wrap_from_resolver(wrapped, instance, args, kwargs): result = wrapped(*args, **kwargs) try: - field = bind_from_resolver(*args, **kwargs) + field = bind_from_resolver(*args, **kwargs) except TypeError: pass else: @@ -119,6 +102,8 @@ def wrap_from_resolver(wrapped, instance, args, kwargs): def instrument_strawberry_schema(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "Schema"): if hasattr(module.Schema, "execute"): wrap_function_wrapper(module, "Schema.execute", wrap_execute) @@ -127,11 +112,15 @@ def instrument_strawberry_schema(module): def instrument_strawberry_asgi(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "GraphQL"): - wrap_asgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_asgi_application(module, "GraphQL.__call__", framework=("Strawberry", STRAWBERRY_GRAPHQL_VERSION)) def instrument_strawberry_schema_converter(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "GraphQLCoreConverter"): if hasattr(module.GraphQLCoreConverter, "from_resolver"): wrap_function_wrapper(module, "GraphQLCoreConverter.from_resolver", wrap_from_resolver) diff --git a/tests/agent_features/_test_async_coroutine_trace.py b/tests/agent_features/_test_async_coroutine_trace.py index 51b81f5f6..1250b8c25 100644 --- a/tests/agent_features/_test_async_coroutine_trace.py +++ b/tests/agent_features/_test_async_coroutine_trace.py @@ -28,6 +28,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -41,6 +42,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_awaitable_timing(event_loop, trace, metric): @@ -79,6 +82,8 @@ def _test(): (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + 
(functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) @pytest.mark.parametrize("yield_from", [True, False]) diff --git a/tests/agent_features/_test_async_generator_trace.py b/tests/agent_features/_test_async_generator_trace.py new file mode 100644 index 000000000..30b970c37 --- /dev/null +++ b/tests/agent_features/_test_async_generator_trace.py @@ -0,0 +1,548 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import sys +import time + +import pytest +from testing_support.fixtures import capture_transaction_metrics, validate_tt_parenting +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +asyncio = pytest.importorskip("asyncio") + + +@pytest.mark.parametrize( + "trace,metric", + [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), + ], +) +def test_async_generator_timing(event_loop, trace, metric): + @trace() + async def simple_gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_async_generator_timing", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_async_generator_timing") + def _test_async_generator_timing(): + async def _test(): + async for _ in simple_gen(): + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_timing() + + # Check that async generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +class MyException(Exception): + pass + + +@validate_transaction_metrics( + "test_async_generator_error", + background_task=True, + 
scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@validate_transaction_errors(errors=["_test_async_generator_trace:MyException"]) +def test_async_generator_error(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + @background_task(name="test_async_generator_error") + async def _test(): + gen = agen() + await gen.asend(None) + await gen.athrow(MyException) + + with pytest.raises(MyException): + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_async_generator_caught_exception", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@validate_transaction_errors(errors=[]) +def test_async_generator_caught_exception(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + time.sleep(0.1) + try: + yield + except ValueError: + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_caught_exception") + def _test_async_generator_caught_exception(): + async def _test(): + gen = agen() + # kickstart the generator (the try/except logic is inside the + # generator) + await gen.asend(None) + await gen.athrow(ValueError) + + # consume the generator + async for _ in gen: + pass + + # The ValueError should not be reraised + event_loop.run_until_complete(_test()) + _test_async_generator_caught_exception() + + assert full_metrics[("Function/agen", "")].total_call_time >= 0.2 + + +@validate_transaction_metrics( + "test_async_generator_handles_terminal_nodes", + background_task=True, + scoped_metrics=[("Function/parent", 1), ("Function/agen", None)], + rollup_metrics=[("Function/parent", 1), ("Function/agen", None)], +) +def test_async_generator_handles_terminal_nodes(event_loop): + # sometimes coroutines can be called underneath terminal nodes + # In this case, the trace shouldn't actually be created and we also + # shouldn't get any errors + + @function_trace(name="agen") + async def agen(): + yield + time.sleep(0.1) + + @function_trace(name="parent", terminal=True) + async def parent(): + # parent calls child + async for _ in agen(): + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_handles_terminal_nodes") + def _test_async_generator_handles_terminal_nodes(): + async def _test(): + await parent() + + event_loop.run_until_complete(_test()) + _test_async_generator_handles_terminal_nodes() + + metric_key = ("Function/parent", "") + assert full_metrics[metric_key].total_exclusive_call_time >= 0.1 + + +@validate_transaction_metrics( + "test_async_generator_close_ends_trace", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_async_generator_close_ends_trace(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + @background_task(name="test_async_generator_close_ends_trace") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + # trace should be ended/recorded by close + await gen.aclose() + + # We may call gen.close as many times as we want + await gen.aclose() + + event_loop.run_until_complete(_test()) + +@validate_tt_parenting( + ( + "TransactionNode", + [ + ( + "FunctionNode", + [ + ("FunctionNode", []), + ], + ), + ], + ) +) +@validate_transaction_metrics( + "test_async_generator_parents", + background_task=True, + 
scoped_metrics=[("Function/child", 1), ("Function/parent", 1)], + rollup_metrics=[("Function/child", 1), ("Function/parent", 1)], +) +def test_async_generator_parents(event_loop): + @function_trace(name="child") + async def child(): + yield + time.sleep(0.1) + yield + + @function_trace(name="parent") + async def parent(): + time.sleep(0.1) + yield + async for _ in child(): + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_parents") + def _test_async_generator_parents(): + async def _test(): + async for _ in parent(): + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_parents() + + # Check that the child time is subtracted from the parent time (parenting + # relationship is correctly established) + key = ("Function/parent", "") + assert full_metrics[key].total_exclusive_call_time < 0.2 + + +@validate_transaction_metrics( + "test_asend_receives_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_asend_receives_a_value(event_loop): + _received = [] + @function_trace(name="agen") + async def agen(): + value = yield + _received.append(value) + yield value + + @background_task(name="test_asend_receives_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + assert await gen.asend("foobar") == "foobar" + assert _received and _received[0] == "foobar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_athrow_yields_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_athrow_yields_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + try: + yield + except MyException: + yield "foobar" + + @background_task(name="test_athrow_yields_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + assert await gen.athrow(MyException) == "foobar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_multiple_throws_yield_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_multiple_throws_yield_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + value = None + for _ in range(4): + try: + yield value + value = "bar" + except MyException: + value = "foo" + + + @background_task(name="test_multiple_throws_yield_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + assert await gen.asend(None) is None + assert await gen.athrow(MyException) == "foo" + assert await gen.athrow(MyException) == "foo" + assert await gen.asend(None) == "bar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_athrow_does_not_yield_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_athrow_does_not_yield_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + try: + yield + except MyException: + return + + 
@background_task(name="test_athrow_does_not_yield_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + # async generator will raise StopAsyncIteration + with pytest.raises(StopAsyncIteration): + await gen.athrow(MyException) + + + event_loop.run_until_complete(_test()) + + +@pytest.mark.parametrize( + "trace", + [ + function_trace(name="simple_gen"), + external_trace(library="lib", url="http://foo.com"), + database_trace("select * from foo"), + datastore_trace("lib", "foo", "bar"), + message_trace("lib", "op", "typ", "name"), + memcache_trace("cmd"), + ], +) +def test_async_generator_functions_outside_of_transaction(event_loop, trace): + @trace + async def agen(): + for _ in range(2): + yield "foo" + + async def _test(): + assert [_ async for _ in agen()] == ["foo", "foo"] + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_catching_generator_exit_causes_runtime_error", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_catching_generator_exit_causes_runtime_error(event_loop): + @function_trace(name="agen") + async def agen(): + try: + yield + except GeneratorExit: + yield + + @background_task(name="test_catching_generator_exit_causes_runtime_error") + async def _test(): + gen = agen() + + # kickstart the coroutine (we're inside the try now) + await gen.asend(None) + + # Generators cannot catch generator exit exceptions (which are injected by + # close). This will result in a runtime error. + with pytest.raises(RuntimeError): + await gen.aclose() + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_async_generator_time_excludes_creation_time", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_async_generator_time_excludes_creation_time(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_time_excludes_creation_time") + def _test_async_generator_time_excludes_creation_time(): + async def _test(): + gen = agen() + time.sleep(0.1) + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_time_excludes_creation_time() + + # check that the trace does not include the time between creation and + # consumption + assert full_metrics[("Function/agen", "")].total_call_time < 0.1 + + +@validate_transaction_metrics( + "test_complete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@background_task(name="test_complete_async_generator") +def test_complete_async_generator(event_loop): + @function_trace(name="agen") + async def agen(): + for i in range(5): + yield i + + async def _test(): + gen = agen() + assert [x async for x in gen] == [x for x in range(5)] + + event_loop.run_until_complete(_test()) + + +@pytest.mark.parametrize("nr_transaction", [True, False]) +def test_incomplete_async_generator(event_loop, nr_transaction): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + def _test_incomplete_async_generator(): + async def _test(): + c = agen() + + async for _ in c: + break + + if nr_transaction: + _test = background_task(name="test_incomplete_async_generator")(_test) + + event_loop.run_until_complete(_test()) + + if nr_transaction: + 
_test_incomplete_async_generator = validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + )(_test_incomplete_async_generator) + + _test_incomplete_async_generator() + + +def test_incomplete_async_generator_transaction_exited(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + @validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + ) + def _test_incomplete_async_generator(): + c = agen() + @background_task(name="test_incomplete_async_generator") + async def _test(): + async for _ in c: + break + + event_loop.run_until_complete(_test()) + + # Remove generator after transaction completes + del c + + _test_incomplete_async_generator() diff --git a/tests/agent_features/test_async_generator_trace.py b/tests/agent_features/test_async_generator_trace.py new file mode 100644 index 000000000..208cf1588 --- /dev/null +++ b/tests/agent_features/test_async_generator_trace.py @@ -0,0 +1,19 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +# Async Generators were introduced in Python 3.6, but some APIs weren't completely stable until Python 3.7. +if sys.version_info >= (3, 7): + from _test_async_generator_trace import * # NOQA diff --git a/tests/agent_features/test_async_wrapper_detection.py b/tests/agent_features/test_async_wrapper_detection.py new file mode 100644 index 000000000..bb1fd3f1e --- /dev/null +++ b/tests/agent_features/test_async_wrapper_detection.py @@ -0,0 +1,102 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
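The two wrapper-detection tests that follow exercise how trace decorators choose an async wrapper: the decorator introspects the wrapped function, so a plain function that merely returns a generator object is invisible to that introspection and must be given async_wrapper= explicitly. A minimal, hypothetical sketch of such detection logic using only stdlib introspection (pick_wrapper and the returned names are illustrative, not the agent's API):

import inspect

def pick_wrapper(func):
    # Introspection can only see how func is defined, not what it returns.
    if inspect.isasyncgenfunction(func):
        return "async_generator_wrapper"
    if inspect.iscoroutinefunction(func):
        return "coroutine_wrapper"
    if inspect.isgeneratorfunction(func):
        return "generator_wrapper"
    return None  # plain synchronous function; no async wrapper applied

def returns_generator():
    def gen():
        yield
    return gen()  # yields a generator at runtime, yet pick_wrapper(returns_generator) is None

This is why the manual test below passes async_wrapper=generator_wrapper by hand.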
+ +import pytest + +import functools +import time + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +from newrelic.common.async_wrapper import generator_wrapper + +from testing_support.fixtures import capture_transaction_metrics +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +trace_metric_cases = [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), +] + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_automatic_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace() + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + for _ in gen(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_manual_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_manual_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_manual_generator_trace_wrapper") + def _test(): + @trace(async_wrapper=generator_wrapper) + def wrapper_func(): + """Function that returns a generator object, obscuring the automatic introspection of async_wrapper()""" + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + return gen() + + for _ in wrapper_func(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 diff --git a/tests/agent_features/test_coroutine_trace.py b/tests/agent_features/test_coroutine_trace.py index 36e365bc4..2043f1326 100644 --- a/tests/agent_features/test_coroutine_trace.py +++ b/tests/agent_features/test_coroutine_trace.py @@ -31,6 +31,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace 
+from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -47,6 +48,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_coroutine_timing(trace, metric): @@ -337,6 +340,37 @@ def coro(): pass +@validate_transaction_metrics( + "test_multiple_throws_yield_a_value", + background_task=True, + scoped_metrics=[("Function/coro", 1)], + rollup_metrics=[("Function/coro", 1)], +) +@background_task(name="test_multiple_throws_yield_a_value") +def test_multiple_throws_yield_a_value(): + @function_trace(name="coro") + def coro(): + value = None + for _ in range(4): + try: + yield value + value = "bar" + except MyException: + value = "foo" + + c = coro() + + # kickstart the coroutine + assert next(c) is None + assert c.throw(MyException) == "foo" + assert c.throw(MyException) == "foo" + assert next(c) == "bar" + + # finish consumption of the coroutine if necessary + for _ in c: + pass + + @pytest.mark.parametrize( "trace", [ diff --git a/tests/agent_features/test_datastore_trace.py b/tests/agent_features/test_datastore_trace.py new file mode 100644 index 000000000..08067e040 --- /dev/null +++ b/tests/agent_features/test_datastore_trace.py @@ -0,0 +1,89 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
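The datastore trace tests that follow rely on trace parameters that may be plain values or callables resolved lazily when the wrapped function runs. A minimal sketch of that value-or-callable pattern (resolve_arg and describe_trace are illustrative helpers, not the agent's implementation):

def resolve_arg(value):
    # Late-bind: zero-argument callables are evaluated at call time;
    # anything else passes through unchanged.
    return value() if callable(value) else value

def describe_trace(product, target, operation):
    # A real wrapper would open a DatastoreTrace around the wrapped call
    # using these resolved parameters.
    return {
        "product": resolve_arg(product),
        "target": resolve_arg(target),
        "operation": resolve_arg(operation),
    }

assert describe_trace(lambda: "Agent Features", "test_target", "test_operation") == {
    "product": "Agent Features",
    "target": "test_target",
    "operation": "test_operation",
}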
+ +from testing_support.validators.validate_datastore_trace_inputs import ( + validate_datastore_trace_inputs, +) + +from newrelic.api.background_task import background_task +from newrelic.api.datastore_trace import DatastoreTrace, DatastoreTraceWrapper + + +@validate_datastore_trace_inputs( + operation="test_operation", + target="test_target", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", +) +@background_task() +def test_dt_trace_all_args(): + with DatastoreTrace( + product="Agent Features", + target="test_target", + operation="test_operation", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", + ): + pass + + +@validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None) +@background_task() +def test_dt_trace_empty(): + with DatastoreTrace(product=None, target=None, operation=None): + pass + + +@background_task() +def test_dt_trace_callable_args(): + def product_callable(): + return "Agent Features" + + def target_callable(): + return "test_target" + + def operation_callable(): + return "test_operation" + + def host_callable(): + return "test_host" + + def port_path_id_callable(): + return "test_port" + + def db_name_callable(): + return "test_db_name" + + @validate_datastore_trace_inputs( + operation="test_operation", + target="test_target", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", + ) + def _test(): + pass + + wrapped_fn = DatastoreTraceWrapper( + _test, + product=product_callable, + target=target_callable, + operation=operation_callable, + host=host_callable, + port_path_or_id=port_path_id_callable, + database_name=db_name_callable, + ) + wrapped_fn() diff --git a/tests/agent_unittests/test_package_version_utils.py b/tests/agent_unittests/test_package_version_utils.py index 435d74947..30c22cff1 100644 --- a/tests/agent_unittests/test_package_version_utils.py +++ b/tests/agent_unittests/test_package_version_utils.py @@ -24,11 +24,19 @@ get_package_version_tuple, ) +# Notes: +# importlib.metadata was a provisional addition to the std library in PY38 and PY39, +# while pkg_resources was deprecated. +# importlib.metadata is no longer provisional in PY310+, which also added new +# attributes such as packages_distributions. + IS_PY38_PLUS = sys.version_info[:2] >= (3, 8) +IS_PY310_PLUS = sys.version_info[:2] >= (3, 10) SKIP_IF_NOT_IMPORTLIB_METADATA = pytest.mark.skipif(not IS_PY38_PLUS, reason="importlib.metadata is not supported.") SKIP_IF_IMPORTLIB_METADATA = pytest.mark.skipif( IS_PY38_PLUS, reason="importlib.metadata is preferred over pkg_resources." 
) +SKIP_IF_NOT_PY310_PLUS = pytest.mark.skipif(not IS_PY310_PLUS, reason="These features were added in 3.10+") @pytest.fixture(scope="function", autouse=True) @@ -38,8 +46,10 @@ def patched_pytest_module(monkeypatch): monkeypatch.delattr(pytest, attr) yield pytest + - +# This test only runs on Python 3.7 and earlier, where pkg_resources is used +@SKIP_IF_IMPORTLIB_METADATA @pytest.mark.parametrize( "attr,value,expected_value", ( @@ -58,6 +68,8 @@ def test_get_package_version(attr, value, expected_value): delattr(pytest, attr) +# This test only runs on Python 3.7 and earlier, where pkg_resources is used +@SKIP_IF_IMPORTLIB_METADATA def test_skips_version_callables(): # There is no file/module here, so we monkeypatch # pytest instead for our purposes @@ -72,6 +84,8 @@ def test_skips_version_callables(): delattr(pytest, "version_tuple") +# This test only runs on Python 3.7 and earlier, where pkg_resources is used +@SKIP_IF_IMPORTLIB_METADATA @pytest.mark.parametrize( "attr,value,expected_value", ( @@ -97,6 +111,13 @@ def test_importlib_metadata(): assert version not in NULL_VERSIONS, version +@SKIP_IF_NOT_PY310_PLUS +@validate_function_called("importlib.metadata", "packages_distributions") +def test_mapping_import_to_distribution_packages(): + version = get_package_version("pytest") + assert version not in NULL_VERSIONS, version + + @SKIP_IF_IMPORTLIB_METADATA @validate_function_called("pkg_resources", "get_distribution") def test_pkg_resources_metadata(): diff --git a/tests/component_djangorestframework/test_application.py b/tests/component_djangorestframework/test_application.py index c036f068d..2874d934f 100644 --- a/tests/component_djangorestframework/test_application.py +++ b/tests/component_djangorestframework/test_application.py @@ -12,190 +12,169 @@ # See the License for the specific language governing permissions and # limitations under the License. +import django import pytest import webtest +from testing_support.fixtures import function_not_called, override_generic_settings +from testing_support.validators.validate_code_level_metrics import ( + validate_code_level_metrics, +) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) -from newrelic.packages import six from newrelic.core.config import global_settings +from newrelic.packages import six -from testing_support.fixtures import ( - override_generic_settings, - function_not_called) -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_code_level_metrics import validate_code_level_metrics -import django - -DJANGO_VERSION = tuple(map(int, django.get_version().split('.')[:2])) +DJANGO_VERSION = tuple(map(int, django.get_version().split(".")[:2])) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def target_application(): from wsgi import application + test_application = webtest.TestApp(application) return test_application if DJANGO_VERSION >= (1, 10): - url_module_path = 'django.urls.resolvers' + url_module_path = "django.urls.resolvers" # Django 1.10 new style middleware removed individual process_* methods. # All middleware in Django 1.10+ is called through the __call__ methods on # middlewares. 
- process_request_method = '' - process_view_method = '' - process_response_method = '' + process_request_method = "" + process_view_method = "" + process_response_method = "" else: - url_module_path = 'django.core.urlresolvers' - process_request_method = '.process_request' - process_view_method = '.process_view' - process_response_method = '.process_response' + url_module_path = "django.core.urlresolvers" + process_request_method = ".process_request" + process_view_method = ".process_view" + process_response_method = ".process_response" if DJANGO_VERSION >= (2, 0): - url_resolver_cls = 'URLResolver' + url_resolver_cls = "URLResolver" else: - url_resolver_cls = 'RegexURLResolver' + url_resolver_cls = "RegexURLResolver" _scoped_metrics = [ - ('Function/django.core.handlers.wsgi:WSGIHandler.__call__', 1), - ('Python/WSGI/Application', 1), - ('Python/WSGI/Response', 1), - ('Python/WSGI/Finalize', 1), - (('Function/django.middleware.common:' - 'CommonMiddleware' + process_request_method), 1), - (('Function/django.contrib.sessions.middleware:' - 'SessionMiddleware' + process_request_method), 1), - (('Function/django.contrib.auth.middleware:' - 'AuthenticationMiddleware' + process_request_method), 1), - (('Function/django.contrib.messages.middleware:' - 'MessageMiddleware' + process_request_method), 1), - (('Function/%s:' % url_module_path + - '%s.resolve' % url_resolver_cls), 1), - (('Function/django.middleware.csrf:' - 'CsrfViewMiddleware' + process_view_method), 1), - (('Function/django.contrib.messages.middleware:' - 'MessageMiddleware' + process_response_method), 1), - (('Function/django.middleware.csrf:' - 'CsrfViewMiddleware' + process_response_method), 1), - (('Function/django.contrib.sessions.middleware:' - 'SessionMiddleware' + process_response_method), 1), - (('Function/django.middleware.common:' - 'CommonMiddleware' + process_response_method), 1), + ("Function/django.core.handlers.wsgi:WSGIHandler.__call__", 1), + ("Python/WSGI/Application", 1), + ("Python/WSGI/Response", 1), + ("Python/WSGI/Finalize", 1), + (("Function/django.middleware.common:CommonMiddleware%s" % process_request_method), 1), + (("Function/django.contrib.sessions.middleware:SessionMiddleware%s" % process_request_method), 1), + (("Function/django.contrib.auth.middleware:AuthenticationMiddleware%s" % process_request_method), 1), + (("Function/django.contrib.messages.middleware:MessageMiddleware%s" % process_request_method), 1), + (("Function/%s:%s.resolve" % (url_module_path, url_resolver_cls)), 1), + (("Function/django.middleware.csrf:CsrfViewMiddleware%s" % process_view_method), 1), + (("Function/django.contrib.messages.middleware:MessageMiddleware%s" % process_response_method), 1), + (("Function/django.middleware.csrf:CsrfViewMiddleware%s" % process_response_method), 1), + (("Function/django.contrib.sessions.middleware:SessionMiddleware%s" % process_response_method), 1), + (("Function/django.middleware.common:CommonMiddleware%s" % process_response_method), 1), ] _test_application_index_scoped_metrics = list(_scoped_metrics) -_test_application_index_scoped_metrics.append(('Function/views:index', 1)) +_test_application_index_scoped_metrics.append(("Function/views:index", 1)) if DJANGO_VERSION >= (1, 5): - _test_application_index_scoped_metrics.extend([ - ('Function/django.http.response:HttpResponse.close', 1)]) + _test_application_index_scoped_metrics.extend([("Function/django.http.response:HttpResponse.close", 1)]) @validate_transaction_errors(errors=[]) -@validate_transaction_metrics('views:index', - 
scoped_metrics=_test_application_index_scoped_metrics) +@validate_transaction_metrics("views:index", scoped_metrics=_test_application_index_scoped_metrics) @validate_code_level_metrics("views", "index") def test_application_index(target_application): - response = target_application.get('') - response.mustcontain('INDEX RESPONSE') + response = target_application.get("") + response.mustcontain("INDEX RESPONSE") _test_application_view_scoped_metrics = list(_scoped_metrics) -_test_application_view_scoped_metrics.append(('Function/urls:View.get', 1)) +_test_application_view_scoped_metrics.append(("Function/urls:View.get", 1)) if DJANGO_VERSION >= (1, 5): - _test_application_view_scoped_metrics.extend([ - ('Function/rest_framework.response:Response.close', 1)]) + _test_application_view_scoped_metrics.extend([("Function/rest_framework.response:Response.close", 1)]) @validate_transaction_errors(errors=[]) -@validate_transaction_metrics('urls:View.get', - scoped_metrics=_test_application_view_scoped_metrics) +@validate_transaction_metrics("urls:View.get", scoped_metrics=_test_application_view_scoped_metrics) @validate_code_level_metrics("urls.View", "get") def test_application_view(target_application): - response = target_application.get('/view/') + response = target_application.get("/view/") assert response.status_int == 200 - response.mustcontain('restframework view response') + response.mustcontain("restframework view response") _test_application_view_error_scoped_metrics = list(_scoped_metrics) -_test_application_view_error_scoped_metrics.append( - ('Function/urls:ViewError.get', 1)) +_test_application_view_error_scoped_metrics.append(("Function/urls:ViewError.get", 1)) -@validate_transaction_errors(errors=['urls:Error']) -@validate_transaction_metrics('urls:ViewError.get', - scoped_metrics=_test_application_view_error_scoped_metrics) +@validate_transaction_errors(errors=["urls:Error"]) +@validate_transaction_metrics("urls:ViewError.get", scoped_metrics=_test_application_view_error_scoped_metrics) @validate_code_level_metrics("urls.ViewError", "get") def test_application_view_error(target_application): - target_application.get('/view_error/', status=500) + target_application.get("/view_error/", status=500) _test_application_view_handle_error_scoped_metrics = list(_scoped_metrics) -_test_application_view_handle_error_scoped_metrics.append( - ('Function/urls:ViewHandleError.get', 1)) +_test_application_view_handle_error_scoped_metrics.append(("Function/urls:ViewHandleError.get", 1)) -@pytest.mark.parametrize('status,should_record', [(418, True), (200, False)]) -@pytest.mark.parametrize('use_global_exc_handler', [True, False]) +@pytest.mark.parametrize("status,should_record", [(418, True), (200, False)]) +@pytest.mark.parametrize("use_global_exc_handler", [True, False]) @validate_code_level_metrics("urls.ViewHandleError", "get") -def test_application_view_handle_error(status, should_record, - use_global_exc_handler, target_application): - errors = ['urls:Error'] if should_record else [] +def test_application_view_handle_error(status, should_record, use_global_exc_handler, target_application): + errors = ["urls:Error"] if should_record else [] @validate_transaction_errors(errors=errors) - @validate_transaction_metrics('urls:ViewHandleError.get', - scoped_metrics=_test_application_view_handle_error_scoped_metrics) + @validate_transaction_metrics( + "urls:ViewHandleError.get", scoped_metrics=_test_application_view_handle_error_scoped_metrics + ) def _test(): - response = target_application.get( - 
'/view_handle_error/%s/%s/' % (status, use_global_exc_handler), - status=status) + response = target_application.get("/view_handle_error/%s/%s/" % (status, use_global_exc_handler), status=status) if use_global_exc_handler: - response.mustcontain('exception was handled global') + response.mustcontain("exception was handled global") else: - response.mustcontain('exception was handled not global') + response.mustcontain("exception was handled not global") _test() -_test_api_view_view_name_get = 'urls:wrapped_view.get' +_test_api_view_view_name_get = "urls:wrapped_view.get" _test_api_view_scoped_metrics_get = list(_scoped_metrics) -_test_api_view_scoped_metrics_get.append( - ('Function/%s' % _test_api_view_view_name_get, 1)) +_test_api_view_scoped_metrics_get.append(("Function/%s" % _test_api_view_view_name_get, 1)) @validate_transaction_errors(errors=[]) -@validate_transaction_metrics(_test_api_view_view_name_get, - scoped_metrics=_test_api_view_scoped_metrics_get) +@validate_transaction_metrics(_test_api_view_view_name_get, scoped_metrics=_test_api_view_scoped_metrics_get) @validate_code_level_metrics("urls.WrappedAPIView", "wrapped_view", py2_namespace="urls") def test_api_view_get(target_application): - response = target_application.get('/api_view/') - response.mustcontain('wrapped_view response') + response = target_application.get("/api_view/") + response.mustcontain("wrapped_view response") -_test_api_view_view_name_post = 'urls:wrapped_view.http_method_not_allowed' +_test_api_view_view_name_post = "urls:wrapped_view.http_method_not_allowed" _test_api_view_scoped_metrics_post = list(_scoped_metrics) -_test_api_view_scoped_metrics_post.append( - ('Function/%s' % _test_api_view_view_name_post, 1)) +_test_api_view_scoped_metrics_post.append(("Function/%s" % _test_api_view_view_name_post, 1)) -@validate_transaction_errors( - errors=['rest_framework.exceptions:MethodNotAllowed']) -@validate_transaction_metrics(_test_api_view_view_name_post, - scoped_metrics=_test_api_view_scoped_metrics_post) +@validate_transaction_errors(errors=["rest_framework.exceptions:MethodNotAllowed"]) +@validate_transaction_metrics(_test_api_view_view_name_post, scoped_metrics=_test_api_view_scoped_metrics_post) def test_api_view_method_not_allowed(target_application): - target_application.post('/api_view/', status=405) + target_application.post("/api_view/", status=405) def test_application_view_agent_disabled(target_application): settings = global_settings() - @override_generic_settings(settings, {'enabled': False}) - @function_not_called('newrelic.core.stats_engine', - 'StatsEngine.record_transaction') + @override_generic_settings(settings, {"enabled": False}) + @function_not_called("newrelic.core.stats_engine", "StatsEngine.record_transaction") def _test(): - response = target_application.get('/view/') + response = target_application.get("/view/") assert response.status_int == 200 - response.mustcontain('restframework view response') + response.mustcontain("restframework view response") _test() diff --git a/tests/component_flask_rest/test_application.py b/tests/component_flask_rest/test_application.py index d463a0205..0decc8ba7 100644 --- a/tests/component_flask_rest/test_application.py +++ b/tests/component_flask_rest/test_application.py @@ -62,7 +62,7 @@ def application(request): ] -@validate_code_level_metrics(TEST_APPLICATION_PREFIX + ".IndexResource", "get") +@validate_code_level_metrics("_test_application.create_app.<locals>.IndexResource", "get", py2_namespace="_test_application.IndexResource") 
@validate_transaction_errors(errors=[]) @validate_transaction_metrics("_test_application:index", scoped_metrics=_test_application_index_scoped_metrics) def test_application_index(application): @@ -88,7 +88,7 @@ def test_application_index(application): ], ) def test_application_raises(exception, status_code, ignore_status_code, propagate_exceptions, application): - @validate_code_level_metrics(TEST_APPLICATION_PREFIX + ".ExceptionResource", "get") + @validate_code_level_metrics("_test_application.create_app.<locals>.ExceptionResource", "get", py2_namespace="_test_application.ExceptionResource") @validate_transaction_metrics("_test_application:exception", scoped_metrics=_test_application_raises_scoped_metrics) def _test(): try: @@ -118,4 +118,4 @@ def test_application_outside_transaction(application): def _test(): application.get("/exception/werkzeug.exceptions:HTTPException/404", status=404) - _test() + _test() \ No newline at end of file diff --git a/tests/component_graphqlserver/_target_schema_async.py b/tests/component_graphqlserver/_target_schema_async.py index c48be2126..aff587bc8 100644 --- a/tests/component_graphqlserver/_target_schema_async.py +++ b/tests/component_graphqlserver/_target_schema_async.py @@ -103,62 +103,33 @@ async def resolve_error(root, info): raise RuntimeError("Runtime Error!") -try: - hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) - library_field = GraphQLField( - Library, - resolver=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolver=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolver=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolver=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolver=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) -except TypeError: - hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) - library_field = GraphQLField( - Library, - resolve=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolve=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolve=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolve=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolve=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) +hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) +library_field = GraphQLField( + Library, + resolver=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) 
+search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolver=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolver=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolver=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolver=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) query = GraphQLObjectType( name="Query", diff --git a/tests/component_graphqlserver/test_graphql.py b/tests/component_graphqlserver/test_graphql.py index 22cfda306..098f50970 100644 --- a/tests/component_graphqlserver/test_graphql.py +++ b/tests/component_graphqlserver/test_graphql.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. + import importlib + import pytest from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.validators.validate_span_events import validate_span_events from testing_support.validators.validate_transaction_count import ( validate_transaction_count, ) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.common.object_names import callable_name @@ -32,9 +38,11 @@ def is_graphql_2(): major_version = int(version.split(".")[0]) return major_version == 2 + @pytest.fixture(scope="session", params=("Sanic", "Flask")) def target_application(request): from . 
import _test_graphql + framework = request.param version = importlib.import_module(framework.lower()).__version__ diff --git a/tests/cross_agent/test_agent_attributes.py b/tests/cross_agent/test_agent_attributes.py index c254be772..527b31a75 100644 --- a/tests/cross_agent/test_agent_attributes.py +++ b/tests/cross_agent/test_agent_attributes.py @@ -40,7 +40,8 @@ def _default_settings(): 'browser_monitoring.attributes.exclude': [], } -FIXTURE = os.path.join(os.curdir, 'fixtures', 'attribute_configuration.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'attribute_configuration.json') def _load_tests(): with open(FIXTURE, 'r') as fh: diff --git a/tests/cross_agent/test_datstore_instance.py b/tests/cross_agent/test_datstore_instance.py index aa095400f..e2a7c0b15 100644 --- a/tests/cross_agent/test_datstore_instance.py +++ b/tests/cross_agent/test_datstore_instance.py @@ -23,7 +23,8 @@ from newrelic.core.database_node import DatabaseNode from newrelic.core.stats_engine import StatsEngine -FIXTURE = os.path.join(os.curdir, +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'datastores', 'datastore_instances.json') _parameters_list = ['name', 'system_hostname', 'db_hostname', diff --git a/tests/cross_agent/test_docker.py b/tests/cross_agent/test_docker.py index 9bc1a7363..fd919932b 100644 --- a/tests/cross_agent/test_docker.py +++ b/tests/cross_agent/test_docker.py @@ -19,7 +19,8 @@ import newrelic.common.utilization as u -DOCKER_FIXTURE = os.path.join(os.curdir, 'fixtures', 'docker_container_id') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +DOCKER_FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'docker_container_id') def _load_docker_test_attributes(): diff --git a/tests/cross_agent/test_labels_and_rollups.py b/tests/cross_agent/test_labels_and_rollups.py index d333ec35b..15ebb1e36 100644 --- a/tests/cross_agent/test_labels_and_rollups.py +++ b/tests/cross_agent/test_labels_and_rollups.py @@ -21,7 +21,8 @@ from testing_support.fixtures import override_application_settings -FIXTURE = os.path.join(os.curdir, 'fixtures', 'labels.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'labels.json') def _load_tests(): with open(FIXTURE, 'r') as fh: diff --git a/tests/cross_agent/test_rules.py b/tests/cross_agent/test_rules.py index e37db787c..ce2983c90 100644 --- a/tests/cross_agent/test_rules.py +++ b/tests/cross_agent/test_rules.py @@ -16,23 +16,23 @@ import os import pytest -from newrelic.core.rules_engine import RulesEngine, NormalizationRule +from newrelic.api.application import application_instance +from newrelic.api.background_task import background_task +from newrelic.api.transaction import record_custom_metric +from newrelic.core.rules_engine import RulesEngine + +from testing_support.validators.validate_metric_payload import validate_metric_payload CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) FIXTURE = os.path.normpath(os.path.join( CURRENT_DIR, 'fixtures', 'rules.json')) + def _load_tests(): with open(FIXTURE, 'r') as fh: js = fh.read() return json.loads(js) -def _prepare_rules(test_rules): - # ensure all keys are present, if not present set to an empty string - for rule in test_rules: - for key in NormalizationRule._fields: - rule[key] = rule.get(key, '') - return test_rules def _make_case_insensitive(rules): # lowercase each rule @@ -42,14 +42,14 @@ def 
_make_case_insensitive(rules): rule['replacement'] = rule['replacement'].lower() return rules + @pytest.mark.parametrize('test_group', _load_tests()) def test_rules_engine(test_group): # FIXME: The test fixture assumes that matching is case insensitive when it # is not. To avoid errors, just lowercase all rules, inputs, and expected # values. - insense_rules = _make_case_insensitive(test_group['rules']) - test_rules = _prepare_rules(insense_rules) + test_rules = _make_case_insensitive(test_group['rules']) rules_engine = RulesEngine(test_rules) for test in test_group['tests']: @@ -66,3 +66,46 @@ def test_rules_engine(test_group): assert expected == '' else: assert result == expected + + +@pytest.mark.parametrize('test_group', _load_tests()) +def test_rules_engine_metric_harvest(test_group): + # FIXME: The test fixture assumes that matching is case insensitive when it + # is not. To avoid errors, just lowercase all rules, inputs, and expected + # values. + test_rules = _make_case_insensitive(test_group['rules']) + rules_engine = RulesEngine(test_rules) + + # Set rules engine on core application + api_application = application_instance(activate=False) + api_name = api_application.name + core_application = api_application._agent.application(api_name) + old_rules = core_application._rules_engine["metric"]  # save previous rules + core_application._rules_engine["metric"] = rules_engine + + def send_metrics(): + # Send all metrics in this test batch in one transaction, then harvest so the normalizer is run. + @background_task(name="send_metrics") + def _test(): + for test in test_group['tests']: + # lowercase each value + input_str = test['input'].lower() + record_custom_metric(input_str, {"count": 1}) + _test() + core_application.harvest() + + try: + # Create a list of all result metrics to validate after harvest + test_metrics = [] + for test in test_group['tests']: + expected = (test['expected'] or '').lower() + if expected == '':  # Ignored + test_metrics.append((expected, None)) + else: + test_metrics.append((expected, 1)) + + # Harvest and validate resulting payload + validate_metric_payload(metrics=test_metrics)(send_metrics)() + finally: + # Replace original rules engine + core_application._rules_engine["metric"] = old_rules diff --git a/tests/cross_agent/test_rum_client_config.py b/tests/cross_agent/test_rum_client_config.py index c2a4a465f..5b8da4b84 100644 --- a/tests/cross_agent/test_rum_client_config.py +++ b/tests/cross_agent/test_rum_client_config.py @@ -26,10 +26,11 @@ ) from newrelic.api.wsgi_application import wsgi_application +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, "fixtures", "rum_client_config.json") def _load_tests(): - fixture = os.path.join(os.curdir, "fixtures", "rum_client_config.json") - with open(fixture, "r") as fh: + with open(FIXTURE, "r") as fh: js = fh.read() return json.loads(js) diff --git a/tests/datastore_firestore/conftest.py b/tests/datastore_firestore/conftest.py new file mode 100644 index 000000000..28e138fa2 --- /dev/null +++ b/tests/datastore_firestore/conftest.py @@ -0,0 +1,124 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import uuid + +import pytest + +from google.cloud.firestore import AsyncClient, Client + +from testing_support.db_settings import firestore_settings +from testing_support.fixture.event_loop import event_loop as loop  # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import (  # noqa: F401; pylint: disable=W0611 +    collector_agent_registration_fixture, +    collector_available_fixture, +) + +from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.time_trace import current_trace +from newrelic.common.system_info import LOCALHOST_EQUIVALENTS, gethostname + +DB_SETTINGS = firestore_settings()[0] +FIRESTORE_HOST = DB_SETTINGS["host"] +FIRESTORE_PORT = DB_SETTINGS["port"] + +_default_settings = { +    "transaction_tracer.explain_threshold": 0.0, +    "transaction_tracer.transaction_threshold": 0.0, +    "transaction_tracer.stack_trace_threshold": 0.0, +    "debug.log_data_collector_payloads": True, +    "debug.record_transaction_failure": True, +    "debug.log_explain_plan_queries": True, +} + +collector_agent_registration = collector_agent_registration_fixture( +    app_name="Python Agent Test (datastore_firestore)", +    default_settings=_default_settings, +    linked_applications=["Python Agent Test (datastore)"], +) + + +@pytest.fixture() +def instance_info(): +    host = gethostname() if FIRESTORE_HOST in LOCALHOST_EQUIVALENTS else FIRESTORE_HOST +    return {"host": host, "port_path_or_id": str(FIRESTORE_PORT), "db.instance": "projects/google-cloud-firestore-emulator/databases/(default)"} + + +@pytest.fixture(scope="session") +def client(): +    os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) +    client = Client() +    # Ensure connection is available +    client.collection("healthcheck").document("healthcheck").set( +        {}, retry=None, timeout=5 +    ) +    return client + + +@pytest.fixture(scope="function") +def collection(client): +    collection_ = client.collection("firestore_collection_" + str(uuid.uuid4())) +    yield collection_ +    client.recursive_delete(collection_) + + +@pytest.fixture(scope="session") +def async_client(loop): +    os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) +    client = AsyncClient() +    loop.run_until_complete(client.collection("healthcheck").document("healthcheck").set({}, retry=None, timeout=5))  # Ensure connection is available +    return client + + +@pytest.fixture(scope="function") +def async_collection(async_client, collection): +    # Use the same collection name as the collection fixture +    yield async_client.collection(collection.id) + + +@pytest.fixture(scope="session") +def assert_trace_for_generator(): +    def _assert_trace_for_generator(generator_func, *args, **kwargs): +        txn = current_trace() +        assert not isinstance(txn, DatastoreTrace) + +        # Check for generator trace on collections +        _trace_check = [] +        for _ in generator_func(*args, **kwargs): +            _trace_check.append(isinstance(current_trace(), DatastoreTrace)) +        assert _trace_check and all(_trace_check)  # All checks are True, and at least 1 is present. 
+ assert current_trace() is txn # Generator trace has exited. + + return _assert_trace_for_generator + + +@pytest.fixture(scope="session") +def assert_trace_for_async_generator(loop): + def _assert_trace_for_async_generator(generator_func, *args, **kwargs): + _trace_check = [] + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + + async def coro(): + # Check for generator trace on collections + async for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + + loop.run_until_complete(coro()) + + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. + + return _assert_trace_for_async_generator diff --git a/tests/datastore_firestore/test_async_batching.py b/tests/datastore_firestore/test_async_batching.py new file mode 100644 index 000000000..08890c39a --- /dev/null +++ b/tests/datastore_firestore/test_async_batching.py @@ -0,0 +1,68 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def exercise_async_write_batch(async_client, async_collection): + async def _exercise_async_write_batch(): + docs = [async_collection.document(str(x)) for x in range(1, 4)] + async_batch = async_client.batch() + for doc in docs: + async_batch.set(doc, {}) + + await async_batch.commit() + return _exercise_async_write_batch + + +def test_firestore_async_write_batch(loop, exercise_async_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_write_batch") + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() + + +def test_firestore_async_write_batch_trace_node_datastore_params(loop, exercise_async_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() diff --git a/tests/datastore_firestore/test_async_client.py b/tests/datastore_firestore/test_async_client.py new file mode 100644 index 000000000..1a17181d5 --- /dev/null +++ b/tests/datastore_firestore/test_async_client.py @@ -0,0 +1,83 @@ +# Copyright 2010 New Relic, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def existing_document(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_async_client(async_client, existing_document): + async def _exercise_async_client(): + assert len([_ async for _ in async_client.collections()]) >= 1 + doc = [_ async for _ in async_client.get_all([existing_document])][0] + assert doc.to_dict()["x"] == 1 + return _exercise_async_client + + +def test_firestore_async_client(loop, exercise_async_client): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_client") + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() + + +@background_task() +def test_firestore_async_client_generators(async_client, collection, assert_trace_for_async_generator): + doc = collection.document("test") + doc.set({}) + + assert_trace_for_async_generator(async_client.collections) + assert_trace_for_async_generator(async_client.get_all, [doc]) + + +def test_firestore_async_client_trace_node_datastore_params(loop, exercise_async_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() diff --git a/tests/datastore_firestore/test_async_collections.py b/tests/datastore_firestore/test_async_collections.py new file mode 100644 index 000000000..a1004a720 --- /dev/null +++ b/tests/datastore_firestore/test_async_collections.py @@ -0,0 +1,89 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def exercise_async_collections(async_collection): + async def _exercise_async_collections(): + async_collection.document("DoesNotExist") + await async_collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + await async_collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = await async_collection.get() + assert len(documents_get) == 2 + documents_stream = [_ async for _ in async_collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ async for _ in async_collection.list_documents()] + assert len(documents_list) == 2 + return _exercise_async_collections + + +def test_firestore_async_collections(loop, exercise_async_collections, async_collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % async_collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collections") + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() + + +@background_task() +def test_firestore_async_collections_generators(collection, async_collection, assert_trace_for_async_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_async_generator(async_collection.stream) + assert_trace_for_async_generator(async_collection.list_documents) + + +def test_firestore_async_collections_trace_node_datastore_params(loop, exercise_async_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() diff --git a/tests/datastore_firestore/test_async_documents.py b/tests/datastore_firestore/test_async_documents.py new file mode 100644 index 000000000..9c0a30479 --- /dev/null +++ b/tests/datastore_firestore/test_async_documents.py @@ -0,0 +1,101 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def exercise_async_documents(async_collection): + async def _exercise_async_documents(): + italy_doc = async_collection.document("Italy") + await italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + await italy_doc.get() + italian_cities = italy_doc.collection("cities") + await italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ async for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = async_collection.document("USA") + await usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + await usa_doc.update({"president": "Joe Biden"}) + + await async_collection.document("USA").delete() + return _exercise_async_documents + + +def test_firestore_async_documents(loop, exercise_async_documents): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_documents") + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() + + +@background_task() +def test_firestore_async_documents_generators(collection, async_collection, assert_trace_for_async_generator): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + async_subcollection = async_collection.document(subcollection_doc.id) + + assert_trace_for_async_generator(async_subcollection.collections) + + +def test_firestore_async_documents_trace_node_datastore_params(loop, exercise_async_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() diff --git a/tests/datastore_firestore/test_async_query.py b/tests/datastore_firestore/test_async_query.py new file mode 100644 index 000000000..c3e43d0e4 --- /dev/null +++ b/tests/datastore_firestore/test_async_query.py @@ -0,0 +1,225 @@ +# Copyright 2010 New Relic, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( +    validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( +    validate_tt_collector_json, +) + + +@pytest.fixture(autouse=True) +def sample_data(collection): +    for x in range(1, 6): +        collection.add({"x": x}) + +    subcollection_doc = collection.document("subcollection") +    subcollection_doc.set({}) +    subcollection_doc.collection("subcollection1").add({}) + + +# ===== AsyncQuery ===== + +@pytest.fixture() +def exercise_async_query(async_collection): +    async def _exercise_async_query(): +        async_query = async_collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) +        assert len(await async_query.get()) == 3 +        assert len([_ async for _ in async_query.stream()]) == 3 +    return _exercise_async_query + + +def test_firestore_async_query(loop, exercise_async_query, async_collection): +    _test_scoped_metrics = [ +        ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), +        ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), +    ] + +    _test_rollup_metrics = [ +        ("Datastore/operation/Firestore/get", 1), +        ("Datastore/operation/Firestore/stream", 1), +        ("Datastore/all", 2), +        ("Datastore/allOther", 2), +    ] +    @validate_database_duration() +    @validate_transaction_metrics( +        "test_firestore_async_query", +        scoped_metrics=_test_scoped_metrics, +        rollup_metrics=_test_rollup_metrics, +        background_task=True, +    ) +    @background_task(name="test_firestore_async_query") +    def _test(): +        loop.run_until_complete(exercise_async_query()) + +    _test() + + +@background_task() +def test_firestore_async_query_generators(async_collection, assert_trace_for_async_generator): +    async_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3) +    assert_trace_for_async_generator(async_query.stream) + + +def test_firestore_async_query_trace_node_datastore_params(loop, exercise_async_query, instance_info): +    @validate_tt_collector_json(datastore_params=instance_info) +    @background_task() +    def _test(): +        loop.run_until_complete(exercise_async_query()) + +    _test() + +# ===== AsyncAggregationQuery ===== + +@pytest.fixture() +def exercise_async_aggregation_query(async_collection): +    async def _exercise_async_aggregation_query(): +        async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() +        assert (await async_aggregation_query.get())[0][0].value == 3 +        assert [_ async for _ in async_aggregation_query.stream()][0][0].value == 3 +    return _exercise_async_aggregation_query + + +def test_firestore_async_aggregation_query(loop, exercise_async_aggregation_query, async_collection): +    _test_scoped_metrics = [ +        ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + 
("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_aggregation_query") + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +@background_task() +def test_firestore_async_aggregation_query_generators(async_collection, assert_trace_for_async_generator): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_async_generator(async_aggregation_query.stream) + + +def test_firestore_async_aggregation_query_trace_node_datastore_params(loop, exercise_async_aggregation_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, async_client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a coroutine that returns an async_generator of Cursor objects. + Each Cursor must point at a valid document path. To test this, we can patch the RPC to return 1 Cursor + which is pointed at any document available. The get_partitions will take that and make 2 QueryPartition + objects out of it, which should be enough to ensure we can exercise the generator's tracing. 
+ """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + async def mock_partition_query(*args, **kwargs): + async def _mock_partition_query(): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + return _mock_partition_query() + + monkeypatch.setattr(async_client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_async_collection_group(async_client, async_collection): + async def _exercise_async_collection_group(): + async_collection_group = async_client.collection_group(async_collection.id) + assert len(await async_collection_group.get()) + assert len([d async for d in async_collection_group.stream()]) + + partitions = [p async for p in async_collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(await partitions.pop().query().get()) + assert len(documents) == 6 + return _exercise_async_collection_group + + +def test_firestore_async_collection_group(loop, exercise_async_collection_group, async_collection, patch_partition_queries): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collection_group") + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() + + +@background_task() +def test_firestore_async_collection_group_generators(async_client, async_collection, assert_trace_for_async_generator, patch_partition_queries): + async_collection_group = async_client.collection_group(async_collection.id) + assert_trace_for_async_generator(async_collection_group.get_partitions, 1) + + +def test_firestore_async_collection_group_trace_node_datastore_params(loop, exercise_async_collection_group, instance_info, patch_partition_queries): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() diff --git a/tests/datastore_firestore/test_async_transaction.py b/tests/datastore_firestore/test_async_transaction.py new file mode 100644 index 000000000..134c080bd --- /dev/null +++ b/tests/datastore_firestore/test_async_transaction.py @@ -0,0 +1,149 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_async_transaction_commit(async_client, async_collection): + async def _exercise_async_transaction_commit(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # get a DocumentReference + with pytest.raises(TypeError): # get is currently broken. It attempts to await an async_generator instead of consuming it. + [_ async for _ in async_transaction.get(async_collection.document("doc1"))] + + # get a Query + with pytest.raises(TypeError): # get is currently broken. It attempts to await an async_generator instead of consuming it. + async_query = async_collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ async for _ in async_transaction.get(async_query)]) == 1 + + # get_all on a list of DocumentReferences + with pytest.raises(TypeError): # get_all is currently broken. It attempts to await an async_generator instead of consuming it. + all_docs = async_transaction.get_all([async_collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ async for _ in all_docs]) == 3 + + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 0}) + async_transaction.delete(async_collection.document("doc3")) + + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 2 + return _exercise_async_transaction_commit + + +@pytest.fixture() +def exercise_async_transaction_rollback(async_client, async_collection): + async def _exercise_async_transaction_rollback(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 99}) + async_transaction.delete(async_collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 3 + return _exercise_async_transaction_rollback + + +def test_firestore_async_transaction_commit(loop, exercise_async_transaction_commit, async_collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + # ("Datastore/operation/Firestore/get_all", 2), + # ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + # ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), # Should be 5 if not for broken APIs + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + 
@background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback(loop, exercise_async_transaction_rollback, async_collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() + + +def test_firestore_async_transaction_commit_trace_node_datastore_params(loop, exercise_async_transaction_commit, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback_trace_node_datastore_params(loop, exercise_async_transaction_rollback, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() diff --git a/tests/datastore_firestore/test_batching.py b/tests/datastore_firestore/test_batching.py new file mode 100644 index 000000000..5dcdd7b39 --- /dev/null +++ b/tests/datastore_firestore/test_batching.py @@ -0,0 +1,124 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + +# ===== WriteBatch ===== + + +@pytest.fixture() +def exercise_write_batch(client, collection): + def _exercise_write_batch(): + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = client.batch() + for doc in docs: + batch.set(doc, {}) + + batch.commit() + return _exercise_write_batch + + +def test_firestore_write_batch(exercise_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_write_batch") + def _test(): + exercise_write_batch() + + _test() + + +def test_firestore_write_batch_trace_node_datastore_params(exercise_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_write_batch() + + _test() + + +# ===== BulkWriteBatch ===== + + +@pytest.fixture() +def exercise_bulk_write_batch(client, collection): + def _exercise_bulk_write_batch(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = BulkWriteBatch(client) + for doc in docs: + batch.set(doc, {}) + + batch.commit() + return _exercise_bulk_write_batch + + +def test_firestore_bulk_write_batch(exercise_bulk_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_bulk_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_bulk_write_batch") + def _test(): + exercise_bulk_write_batch() + + _test() + + +def test_firestore_bulk_write_batch_trace_node_datastore_params(exercise_bulk_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_bulk_write_batch() + + _test() diff --git a/tests/datastore_firestore/test_client.py b/tests/datastore_firestore/test_client.py new file mode 100644 index 000000000..06580356a --- /dev/null +++ b/tests/datastore_firestore/test_client.py @@ -0,0 +1,81 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def sample_data(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_client(client, sample_data): + def _exercise_client(): + assert len([_ for _ in client.collections()]) + doc = [_ for _ in client.get_all([sample_data])][0] + assert doc.to_dict()["x"] == 1 + return _exercise_client + + +def test_firestore_client(exercise_client): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_client") + def _test(): + exercise_client() + + _test() + + +@background_task() +def test_firestore_client_generators(client, sample_data, assert_trace_for_generator): + assert_trace_for_generator(client.collections) + assert_trace_for_generator(client.get_all, [sample_data]) + + +def test_firestore_client_trace_node_datastore_params(exercise_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_client() + + _test() diff --git a/tests/datastore_firestore/test_collections.py b/tests/datastore_firestore/test_collections.py new file mode 100644 index 000000000..c5c443dce --- /dev/null +++ b/tests/datastore_firestore/test_collections.py @@ -0,0 +1,92 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_collections(collection): + def _exercise_collections(): + collection.document("DoesNotExist") + collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = collection.get() + assert len(documents_get) == 2 + documents_stream = [_ for _ in collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ for _ in collection.list_documents()] + assert len(documents_list) == 2 + return _exercise_collections + + +def test_firestore_collections(exercise_collections, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collections") + def _test(): + exercise_collections() + + _test() + + +@background_task() +def test_firestore_collections_generators(collection, assert_trace_for_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_generator(collection.stream) + assert_trace_for_generator(collection.list_documents) + + +def test_firestore_collections_trace_node_datastore_params(exercise_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collections() + + _test() diff --git a/tests/datastore_firestore/test_documents.py b/tests/datastore_firestore/test_documents.py new file mode 100644 index 000000000..200689960 --- /dev/null +++ b/tests/datastore_firestore/test_documents.py @@ -0,0 +1,102 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_documents(collection): + def _exercise_documents(): + italy_doc = collection.document("Italy") + italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + italy_doc.get() + italian_cities = italy_doc.collection("cities") + italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = collection.document("USA") + usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + usa_doc.update({"president": "Joe Biden"}) + + collection.document("USA").delete() + return _exercise_documents + + +def test_firestore_documents(exercise_documents): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_documents") + def _test(): + exercise_documents() + + _test() + + +@background_task() +def test_firestore_documents_generators(collection, assert_trace_for_generator): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + assert_trace_for_generator(subcollection_doc.collections) + + +def test_firestore_documents_trace_node_datastore_params(exercise_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_documents() + + _test() diff --git a/tests/datastore_firestore/test_query.py b/tests/datastore_firestore/test_query.py new file mode 100644 index 000000000..5e681f53e --- /dev/null +++ b/tests/datastore_firestore/test_query.py @@ -0,0 +1,229 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== Query ===== + + +@pytest.fixture() +def exercise_query(collection): + def _exercise_query(): + query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(query.get()) == 3 + assert len([_ for _ in query.stream()]) == 3 + return _exercise_query + + +def test_firestore_query(exercise_query, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_query") + def _test(): + exercise_query() + + _test() + + +@background_task() +def test_firestore_query_generators(collection, assert_trace_for_generator): + query = collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_generator(query.stream) + + +def test_firestore_query_trace_node_datastore_params(exercise_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_query() + + _test() + +# ===== AggregationQuery ===== + + +@pytest.fixture() +def exercise_aggregation_query(collection): + def _exercise_aggregation_query(): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert aggregation_query.get()[0][0].value == 3 + assert [_ for _ in aggregation_query.stream()][0][0].value == 3 + return _exercise_aggregation_query + + +def test_firestore_aggregation_query(exercise_aggregation_query, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + 
@background_task(name="test_firestore_aggregation_query") + def _test(): + exercise_aggregation_query() + + _test() + + +@background_task() +def test_firestore_aggregation_query_generators(collection, assert_trace_for_generator): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_generator(aggregation_query.stream) + + +def test_firestore_aggregation_query_trace_node_datastore_params(exercise_aggregation_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_aggregation_query() + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a generator of Cursor objects. Each Cursor must point at a valid document path. + To test this, we can patch the RPC to return 1 Cursor which is pointed at any document available. + The get_partitions will take that and make 2 QueryPartition objects out of it, which should be enough to ensure + we can exercise the generator's tracing. + """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + def mock_partition_query(*args, **kwargs): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + monkeypatch.setattr(client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_collection_group(client, collection, patch_partition_queries): + def _exercise_collection_group(): + collection_group = client.collection_group(collection.id) + assert len(collection_group.get()) + assert len([d for d in collection_group.stream()]) + + partitions = [p for p in collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(partitions.pop().query().get()) + assert len(documents) == 6 + return _exercise_collection_group + + +def test_firestore_collection_group(exercise_collection_group, client, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collection_group") + def _test(): + exercise_collection_group() + + _test() + + +@background_task() +def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries): + collection_group = client.collection_group(collection.id) + assert_trace_for_generator(collection_group.get_partitions, 1) + + +def test_firestore_collection_group_trace_node_datastore_params(exercise_collection_group, instance_info): + 
@validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collection_group() + + _test() diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py new file mode 100644 index 000000000..c322a797e --- /dev/null +++ b/tests/datastore_firestore/test_transaction.py @@ -0,0 +1,149 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_transaction_commit(client, collection): + def _exercise_transaction_commit(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # get a DocumentReference + [_ for _ in transaction.get(collection.document("doc1"))] + + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ for _ in transaction.get(query)]) == 1 + + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ for _ in all_docs]) == 3 + + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) + + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 2 + return _exercise_transaction_commit + + +@pytest.fixture() +def exercise_transaction_rollback(client, collection): + def _exercise_transaction_rollback(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 3 + return _exercise_transaction_rollback + + +def test_firestore_transaction_commit(exercise_transaction_commit, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ("Datastore/operation/Firestore/get_all", 2), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + 
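# get_all is expected twice: the explicit get_all call plus transaction.get on a document reference.
+    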
@validate_transaction_metrics(
+        "test_firestore_transaction_commit",
+        scoped_metrics=_test_scoped_metrics,
+        rollup_metrics=_test_rollup_metrics,
+        background_task=True,
+    )
+    @background_task(name="test_firestore_transaction_commit")
+    def _test():
+        exercise_transaction_commit()
+
+    _test()
+
+
+def test_firestore_transaction_rollback(exercise_transaction_rollback, collection):
+    _test_scoped_metrics = [
+        ("Datastore/operation/Firestore/rollback", 1),
+        ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1),
+    ]
+
+    _test_rollup_metrics = [
+        ("Datastore/operation/Firestore/list_documents", 1),
+        ("Datastore/all", 2),
+        ("Datastore/allOther", 2),
+    ]
+
+    @validate_database_duration()
+    @validate_transaction_metrics(
+        "test_firestore_transaction_rollback",
+        scoped_metrics=_test_scoped_metrics,
+        rollup_metrics=_test_rollup_metrics,
+        background_task=True,
+    )
+    @background_task(name="test_firestore_transaction_rollback")
+    def _test():
+        exercise_transaction_rollback()
+
+    _test()
+
+
+def test_firestore_transaction_commit_trace_node_datastore_params(exercise_transaction_commit, instance_info):
+    @validate_tt_collector_json(datastore_params=instance_info)
+    @background_task()
+    def _test():
+        exercise_transaction_commit()
+
+    _test()
+
+
+def test_firestore_transaction_rollback_trace_node_datastore_params(exercise_transaction_rollback, instance_info):
+    @validate_tt_collector_json(datastore_params=instance_info)
+    @background_task()
+    def _test():
+        exercise_transaction_rollback()
+
+    _test()
diff --git a/tests/datastore_redis/test_asyncio.py b/tests/datastore_redis/test_asyncio.py
index 97c1b7853..5ffdda582 100644
--- a/tests/datastore_redis/test_asyncio.py
+++ b/tests/datastore_redis/test_asyncio.py
@@ -28,20 +28,27 @@
 # Settings
 DB_SETTINGS = redis_settings()[0]
-REDIS_VERSION = get_package_version_tuple("redis")
+REDIS_PY_VERSION = get_package_version_tuple("redis")
 
 # Metrics
-_enable_scoped_metrics = [("Datastore/operation/Redis/publish", 3)]
+_base_scoped_metrics = [("Datastore/operation/Redis/publish", 3)]
 
-_enable_rollup_metrics = [
-    ("Datastore/all", 3),
-    ("Datastore/allOther", 3),
-    ("Datastore/Redis/all", 3),
-    ("Datastore/Redis/allOther", 3),
+if REDIS_PY_VERSION >= (5, 0):
+    _base_scoped_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
+
+datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3
+
+_base_rollup_metrics = [
+    ("Datastore/all", datastore_all_metric_count),
+    ("Datastore/allOther", datastore_all_metric_count),
+    ("Datastore/Redis/all", datastore_all_metric_count),
+    ("Datastore/Redis/allOther", datastore_all_metric_count),
     ("Datastore/operation/Redis/publish", 3),
-    ("Datastore/instance/Redis/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 3),
+    ("Datastore/instance/Redis/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), datastore_all_metric_count),
 ]
+if REDIS_PY_VERSION >= (5, 0):
+    _base_rollup_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
 
 # Tests
@@ -53,7 +60,7 @@ def client(loop):  # noqa
     return loop.run_until_complete(redis.asyncio.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0))
 
-@pytest.mark.skipif(REDIS_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+")
+@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+")
 @validate_transaction_metrics("test_asyncio:test_async_pipeline", background_task=True)
 @background_task()
 def test_async_pipeline(client, loop):  # noqa
@@ -65,11 +72,11 @@ async def 
_test_pipeline(client):
     loop.run_until_complete(_test_pipeline(client))
 
-@pytest.mark.skipif(REDIS_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+")
+@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+")
 @validate_transaction_metrics(
     "test_asyncio:test_async_pubsub",
-    scoped_metrics=_enable_scoped_metrics,
-    rollup_metrics=_enable_rollup_metrics,
+    scoped_metrics=_base_scoped_metrics,
+    rollup_metrics=_base_rollup_metrics,
     background_task=True,
 )
 @background_task()
diff --git a/tests/datastore_redis/test_custom_conn_pool.py b/tests/datastore_redis/test_custom_conn_pool.py
index 156c9ce31..8e4503b75 100644
--- a/tests/datastore_redis/test_custom_conn_pool.py
+++ b/tests/datastore_redis/test_custom_conn_pool.py
@@ -21,14 +21,16 @@
 import redis
 
 from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version_tuple
 
 from testing_support.fixtures import override_application_settings
 from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
 from testing_support.db_settings import redis_settings
 from testing_support.util import instance_hostname
+
 
 DB_SETTINGS = redis_settings()[0]
-REDIS_PY_VERSION = redis.VERSION
+REDIS_PY_VERSION = get_package_version_tuple("redis")
 
 class FakeConnectionPool(object):
@@ -56,39 +58,41 @@ def release(self, connection):
 # We don't record instance metrics when using redis blaster,
 # so we just check for base metrics.
+datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3
 
-_base_scoped_metrics = (
+_base_scoped_metrics = [
     ('Datastore/operation/Redis/get', 1),
     ('Datastore/operation/Redis/set', 1),
     ('Datastore/operation/Redis/client_list', 1),
-)
-
-_base_rollup_metrics = (
-    ('Datastore/all', 3),
-    ('Datastore/allOther', 3),
-    ('Datastore/Redis/all', 3),
-    ('Datastore/Redis/allOther', 3),
+]
+# client_setinfo was introduced in v5.0.0 and assigns info displayed in client_list output
+if REDIS_PY_VERSION >= (5, 0):
+    _base_scoped_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
+
+_base_rollup_metrics = [
+    ('Datastore/all', datastore_all_metric_count),
+    ('Datastore/allOther', datastore_all_metric_count),
+    ('Datastore/Redis/all', datastore_all_metric_count),
+    ('Datastore/Redis/allOther', datastore_all_metric_count),
     ('Datastore/operation/Redis/get', 1),
     ('Datastore/operation/Redis/set', 1),
     ('Datastore/operation/Redis/client_list', 1),
-)
-
-_disable_scoped_metrics = list(_base_scoped_metrics)
-_disable_rollup_metrics = list(_base_rollup_metrics)
-
-_enable_scoped_metrics = list(_base_scoped_metrics)
-_enable_rollup_metrics = list(_base_rollup_metrics)
+]
+if REDIS_PY_VERSION >= (5, 0):
+    _base_rollup_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
 
 _host = instance_hostname(DB_SETTINGS['host'])
 _port = DB_SETTINGS['port']
 
 _instance_metric_name = 'Datastore/instance/Redis/%s/%s' % (_host, _port)
 
-_enable_rollup_metrics.append(
-    (_instance_metric_name, 3)
-)
+instance_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3
+
+# list.append returns None, so build these by concatenation instead of
+# assigning the result of append.
+_enable_rollup_metrics = _base_rollup_metrics + [
+    (_instance_metric_name, instance_metric_count),
+]
 
-_disable_rollup_metrics.append(
-    (_instance_metric_name, None)
-)
+_disable_rollup_metrics = _base_rollup_metrics + [
+    (_instance_metric_name, None),
+]
@@ -106,7 +110,7 @@ def exercise_redis(client):
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics(
     'test_custom_conn_pool:test_fake_conn_pool_enable_instance',
-    
scoped_metrics=_enable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
     rollup_metrics=_enable_rollup_metrics,
     background_task=True)
 @background_task()
@@ -132,7 +136,7 @@ def test_fake_conn_pool_enable_instance():
 @override_application_settings(_disable_instance_settings)
 @validate_transaction_metrics(
     'test_custom_conn_pool:test_fake_conn_pool_disable_instance',
-    scoped_metrics=_disable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
     rollup_metrics=_disable_rollup_metrics,
     background_task=True)
 @background_task()
diff --git a/tests/datastore_redis/test_execute_command.py b/tests/datastore_redis/test_execute_command.py
index 747588072..741bc5034 100644
--- a/tests/datastore_redis/test_execute_command.py
+++ b/tests/datastore_redis/test_execute_command.py
@@ -16,6 +16,7 @@
 import redis
 
 from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version_tuple
 
 from testing_support.fixtures import override_application_settings
 from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
@@ -23,7 +24,8 @@
 from testing_support.util import instance_hostname
 
 DB_SETTINGS = redis_settings()[0]
-REDIS_PY_VERSION = redis.VERSION
+REDIS_PY_VERSION = get_package_version_tuple("redis")
+
 
 # Settings
@@ -36,34 +38,34 @@
 
 # Metrics
-_base_scoped_metrics = (
+_base_scoped_metrics = [
     ('Datastore/operation/Redis/client_list', 1),
-)
-
-_base_rollup_metrics = (
-    ('Datastore/all', 1),
-    ('Datastore/allOther', 1),
-    ('Datastore/Redis/all', 1),
-    ('Datastore/Redis/allOther', 1),
+]
+if REDIS_PY_VERSION >= (5, 0):
+    _base_scoped_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
+
+# client_list runs once; client_setinfo adds two more operations on redis>=5.0.
+datastore_all_metric_count = 3 if REDIS_PY_VERSION >= (5, 0) else 1
+
+_base_rollup_metrics = [
+    ('Datastore/all', datastore_all_metric_count),
+    ('Datastore/allOther', datastore_all_metric_count),
+    ('Datastore/Redis/all', datastore_all_metric_count),
+    ('Datastore/Redis/allOther', datastore_all_metric_count),
     ('Datastore/operation/Redis/client_list', 1),
-)
-
-_disable_scoped_metrics = list(_base_scoped_metrics)
-_disable_rollup_metrics = list(_base_rollup_metrics)
-
-_enable_scoped_metrics = list(_base_scoped_metrics)
-_enable_rollup_metrics = list(_base_rollup_metrics)
+]
+if REDIS_PY_VERSION >= (5, 0):
+    _base_rollup_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
 
 _host = instance_hostname(DB_SETTINGS['host'])
 _port = DB_SETTINGS['port']
 
 _instance_metric_name = 'Datastore/instance/Redis/%s/%s' % (_host, _port)
 
-_enable_rollup_metrics.append(
-    (_instance_metric_name, 1)
-)
+instance_metric_count = 3 if REDIS_PY_VERSION >= (5, 0) else 1
+
+# list.append returns None, so build these by concatenation instead of
+# assigning the result of append.
+_enable_rollup_metrics = _base_rollup_metrics + [
+    (_instance_metric_name, instance_metric_count),
+]
 
-_disable_rollup_metrics.append(
-    (_instance_metric_name, None)
-)
+_disable_rollup_metrics = _base_rollup_metrics + [
+    (_instance_metric_name, None),
+]
@@ -76,7 +78,7 @@ def exercise_redis_single_arg(client):
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics(
     'test_execute_command:test_strict_redis_execute_command_two_args_enable',
-    scoped_metrics=_enable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
     rollup_metrics=_enable_rollup_metrics,
     background_task=True)
 @background_task()
@@ -88,7 +90,7 @@ def test_strict_redis_execute_command_two_args_enable():
 @override_application_settings(_disable_instance_settings)
 @validate_transaction_metrics(
     'test_execute_command:test_strict_redis_execute_command_two_args_disabled',
-    scoped_metrics=_disable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
     rollup_metrics=_disable_rollup_metrics,
     background_task=True)
 @background_task()
@@ -100,7 +102,7 @@ def 
test_strict_redis_execute_command_two_args_disabled(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_two_args_enable', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -112,7 +114,7 @@ def test_redis_execute_command_two_args_enable(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_two_args_disabled', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() @@ -126,7 +128,7 @@ def test_redis_execute_command_two_args_disabled(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_strict_redis_execute_command_as_one_arg_enable', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -140,7 +142,7 @@ def test_strict_redis_execute_command_as_one_arg_enable(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_strict_redis_execute_command_as_one_arg_disabled', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() @@ -154,7 +156,7 @@ def test_strict_redis_execute_command_as_one_arg_disabled(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_as_one_arg_enable', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -168,7 +170,7 @@ def test_redis_execute_command_as_one_arg_enable(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_as_one_arg_disabled', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() diff --git a/tests/datastore_redis/test_get_and_set.py b/tests/datastore_redis/test_get_and_set.py index 0e2df4bb1..720433ae3 100644 --- a/tests/datastore_redis/test_get_and_set.py +++ b/tests/datastore_redis/test_get_and_set.py @@ -48,10 +48,7 @@ ('Datastore/operation/Redis/set', 1), ) -_disable_scoped_metrics = list(_base_scoped_metrics) _disable_rollup_metrics = list(_base_rollup_metrics) - -_enable_scoped_metrics = list(_base_scoped_metrics) _enable_rollup_metrics = list(_base_rollup_metrics) _host = instance_hostname(DB_SETTINGS['host']) @@ -78,7 +75,7 @@ def exercise_redis(client): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_strict_redis_operation_enable_instance', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -90,7 +87,7 @@ def test_strict_redis_operation_enable_instance(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_strict_redis_operation_disable_instance', - scoped_metrics=_disable_scoped_metrics, + 
scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() @@ -102,7 +99,7 @@ def test_strict_redis_operation_disable_instance(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_redis_operation_enable_instance', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -114,7 +111,7 @@ def test_redis_operation_enable_instance(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_redis_operation_disable_instance', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() diff --git a/tests/datastore_redis/test_instance_info.py b/tests/datastore_redis/test_instance_info.py index b3e9a0d5d..211e96169 100644 --- a/tests/datastore_redis/test_instance_info.py +++ b/tests/datastore_redis/test_instance_info.py @@ -15,9 +15,10 @@ import pytest import redis +from newrelic.common.package_version_utils import get_package_version_tuple from newrelic.hooks.datastore_redis import _conn_attrs_to_dict, _instance_info -REDIS_PY_VERSION = redis.VERSION +REDIS_PY_VERSION = get_package_version_tuple("redis") _instance_info_tests = [ ((), {}, ("localhost", "6379", "0")), diff --git a/tests/datastore_redis/test_multiple_dbs.py b/tests/datastore_redis/test_multiple_dbs.py index 15777cc38..9a5e299f0 100644 --- a/tests/datastore_redis/test_multiple_dbs.py +++ b/tests/datastore_redis/test_multiple_dbs.py @@ -16,6 +16,7 @@ import redis from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version_tuple from testing_support.fixtures import override_application_settings from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics @@ -23,6 +24,8 @@ from testing_support.util import instance_hostname DB_MULTIPLE_SETTINGS = redis_settings() +REDIS_PY_VERSION = get_package_version_tuple("redis") + # Settings @@ -35,27 +38,31 @@ # Metrics -_base_scoped_metrics = ( +_base_scoped_metrics = [ ('Datastore/operation/Redis/get', 1), ('Datastore/operation/Redis/set', 1), ('Datastore/operation/Redis/client_list', 1), -) - -_base_rollup_metrics = ( - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Redis/all', 3), - ('Datastore/Redis/allOther', 3), +] +# client_setinfo was introduced in v5.0.0 and assigns info displayed in client_list output +if REDIS_PY_VERSION >= (5, 0): + _base_scoped_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),) + +datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 + +_base_rollup_metrics = [ + ('Datastore/all', datastore_all_metric_count), + ('Datastore/allOther', datastore_all_metric_count), + ('Datastore/Redis/all', datastore_all_metric_count), + ('Datastore/Redis/allOther', datastore_all_metric_count), ('Datastore/operation/Redis/get', 1), ('Datastore/operation/Redis/set', 1), ('Datastore/operation/Redis/client_list', 1), -) +] -_disable_scoped_metrics = list(_base_scoped_metrics) -_disable_rollup_metrics = list(_base_rollup_metrics) +# client_setinfo was introduced in v5.0.0 and assigns info displayed in client_list output +if REDIS_PY_VERSION >= (5, 0): + _base_rollup_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),) -_enable_scoped_metrics = 
list(_base_scoped_metrics)
-_enable_rollup_metrics = list(_base_rollup_metrics)
+# The enable/disable rollup lists must exist at import time for the decorators
+# below, even when only one database is configured, so extend copies in place.
+_enable_rollup_metrics = list(_base_rollup_metrics)
+_disable_rollup_metrics = list(_base_rollup_metrics)
 
 if len(DB_MULTIPLE_SETTINGS) > 1:
     redis_1 = DB_MULTIPLE_SETTINGS[0]
@@ -70,16 +77,20 @@
     instance_metric_name_1 = 'Datastore/instance/Redis/%s/%s' % (host_1, port_1)
     instance_metric_name_2 = 'Datastore/instance/Redis/%s/%s' % (host_2, port_2)
 
-    _enable_rollup_metrics.extend([
-        (instance_metric_name_1, 2),
-        (instance_metric_name_2, 1),
+    instance_metric_name_1_count = 2
+    instance_metric_name_2_count = 3 if REDIS_PY_VERSION >= (5, 0) else 1
+
+    _enable_rollup_metrics.extend([
+        (instance_metric_name_1, instance_metric_name_1_count),
+        (instance_metric_name_2, instance_metric_name_2_count),
     ])
 
     _disable_rollup_metrics.extend([
         (instance_metric_name_1, None),
         (instance_metric_name_2, None),
     ])
 
+
 def exercise_redis(client_1, client_2):
     client_1.set('key', 'value')
     client_1.get('key')
@@ -90,7 +101,7 @@ def exercise_redis(client_1, client_2):
         reason='Test environment not configured with multiple databases.')
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics('test_multiple_dbs:test_multiple_datastores_enabled',
-        scoped_metrics=_enable_scoped_metrics,
+        scoped_metrics=_base_scoped_metrics,
         rollup_metrics=_enable_rollup_metrics,
         background_task=True)
 @background_task()
@@ -106,7 +117,7 @@ def test_multiple_datastores_enabled():
         reason='Test environment not configured with multiple databases.')
 @override_application_settings(_disable_instance_settings)
 @validate_transaction_metrics('test_multiple_dbs:test_multiple_datastores_disabled',
-        scoped_metrics=_disable_scoped_metrics,
+        scoped_metrics=_base_scoped_metrics,
         rollup_metrics=_disable_rollup_metrics,
         background_task=True)
 @background_task()
diff --git a/tests/datastore_redis/test_rb.py b/tests/datastore_redis/test_rb.py
index 5678c2787..3b25593be 100644
--- a/tests/datastore_redis/test_rb.py
+++ b/tests/datastore_redis/test_rb.py
@@ -23,6 +23,7 @@
 import six
 
 from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version_tuple
 
 from testing_support.fixtures import override_application_settings
 from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
@@ -30,7 +31,7 @@
 from testing_support.util import instance_hostname
 
 DB_SETTINGS = redis_settings()[0]
-REDIS_PY_VERSION = redis.VERSION
+REDIS_PY_VERSION = get_package_version_tuple("redis")
 
 # Settings
@@ -61,10 +62,7 @@
     ('Datastore/operation/Redis/set', 1),
 )
 
-_disable_scoped_metrics = list(_base_scoped_metrics)
 _disable_rollup_metrics = list(_base_rollup_metrics)
-
-_enable_scoped_metrics = list(_base_scoped_metrics)
 _enable_rollup_metrics = list(_base_rollup_metrics)
 
 _host = instance_hostname(DB_SETTINGS['host'])
@@ -80,25 +78,26 @@
     (_instance_metric_name, None)
 )
 
-# Operations 
+# Operations
 def exercise_redis(routing_client):
     routing_client.set('key', 'value')
     routing_client.get('key')
 
+
 def exercise_fanout(cluster):
     with cluster.fanout(hosts='all') as client:
         client.execute_command('CLIENT', 'LIST')
 
-# Tests 
+# Tests
 @pytest.mark.skipif(six.PY3, reason='Redis Blaster is Python 2 only.')
 @pytest.mark.skipif(REDIS_PY_VERSION < (2, 10, 2), reason='Redis Blaster requires redis>=2.10.2')
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics(
     'test_rb:test_redis_blaster_operation_enable_instance',
-    
scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -121,7 +120,7 @@ def test_redis_blaster_operation_enable_instance(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_rb:test_redis_blaster_operation_disable_instance', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() diff --git a/tests/framework_ariadne/_target_application.py b/tests/framework_ariadne/_target_application.py index a59e7432e..fef782608 100644 --- a/tests/framework_ariadne/_target_application.py +++ b/tests/framework_ariadne/_target_application.py @@ -27,7 +27,7 @@ from framework_ariadne._target_schema_sync import ( target_wsgi_application as target_wsgi_application_sync, ) -from framework_ariadne.test_application import ariadne_version_tuple +from framework_ariadne._target_schema_sync import ariadne_version_tuple from graphql import MiddlewareManager diff --git a/tests/framework_ariadne/_target_schema_sync.py b/tests/framework_ariadne/_target_schema_sync.py index 2725f0866..8860e71ac 100644 --- a/tests/framework_ariadne/_target_schema_sync.py +++ b/tests/framework_ariadne/_target_schema_sync.py @@ -26,7 +26,9 @@ from framework_graphql._target_schema_sync import books, magazines, libraries from testing_support.asgi_testing import AsgiTest -from framework_ariadne.test_application import ariadne_version_tuple +from framework_ariadne.test_application import ARIADNE_VERSION + +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) if ariadne_version_tuple < (0, 16): from ariadne.asgi import GraphQL as GraphQLASGI diff --git a/tests/framework_ariadne/conftest.py b/tests/framework_ariadne/conftest.py index 210399bb9..42b08faba 100644 --- a/tests/framework_ariadne/conftest.py +++ b/tests/framework_ariadne/conftest.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import six -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { "transaction_tracer.explain_threshold": 0.0, diff --git a/tests/framework_graphene/_target_application.py b/tests/framework_graphene/_target_application.py index 22d18897a..3f4b23e57 100644 --- a/tests/framework_graphene/_target_application.py +++ b/tests/framework_graphene/_target_application.py @@ -11,13 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from graphql import __version__ as version
-from newrelic.packages import six
-
+from ._target_schema_async import target_schema as target_schema_async
 from ._target_schema_sync import target_schema as target_schema_sync
-
-
-is_graphql_2 = int(version.split(".")[0]) == 2
+from framework_graphene.test_application import GRAPHENE_VERSION
 
 
 def check_response(query, response):
@@ -34,61 +30,26 @@ def _run_sync(query, middleware=None):
         check_response(query, response)
 
         return response.data
+
     return _run_sync
 
 
 def run_async(schema):
     import asyncio
+
     def _run_async(query, middleware=None):
         loop = asyncio.get_event_loop()
         response = loop.run_until_complete(schema.execute_async(query, middleware=middleware))
         check_response(query, response)
 
         return response.data
-    return _run_async
-
-def run_promise(schema):
-    def _run_promise(query, middleware=None):
-        response = schema.execute(query, middleware=middleware, return_promise=True).get()
-        check_response(query, response)
-
-        return response.data
-    return _run_promise
-
-
-def run_promise(schema, scheduler):
-    from graphql import graphql
-    from promise import set_default_scheduler
-
-    def _run_promise(query, middleware=None):
-        set_default_scheduler(scheduler)
-
-        promise = graphql(schema, query, middleware=middleware, return_promise=True)
-        response = promise.get()
-
-        check_response(query, response)
-
-        return response.data
-
-    return _run_promise
+
+    return _run_async
 
 
 target_application = {
     "sync-sync": run_sync(target_schema_sync),
-}
+    "async-sync": run_async(target_schema_sync),
+    "async-async": run_async(target_schema_async),
+}
 
-if is_graphql_2:
-    from ._target_schema_promise import target_schema as target_schema_promise
-    from promise.schedulers.immediate import ImmediateScheduler
-
-    if six.PY3:
-        from promise.schedulers.asyncio import AsyncioScheduler as AsyncScheduler
-    else:
-        from promise.schedulers.thread import ThreadScheduler as AsyncScheduler
-    target_application["sync-promise"] = run_promise(target_schema_promise, ImmediateScheduler())
-    target_application["async-promise"] = run_promise(target_schema_promise, AsyncScheduler())
-elif six.PY3:
-    from ._target_schema_async import target_schema as target_schema_async
-    target_application["async-sync"] = run_async(target_schema_sync)
-    target_application["async-async"] = run_async(target_schema_async)
diff --git a/tests/framework_graphene/_target_schema_promise.py b/tests/framework_graphene/_target_schema_promise.py
deleted file mode 100644
index 905f47a0b..000000000
--- a/tests/framework_graphene/_target_schema_promise.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright 2010 New Relic, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from graphene import Field, Int, List -from graphene import Mutation as GrapheneMutation -from graphene import NonNull, ObjectType, Schema, String, Union -from promise import promisify - -from ._target_schema_sync import Author, Book, Magazine, Item, Library, Storage, authors, books, magazines, libraries - - -storage = [] - - -@promisify -def resolve_library(self, info, index): - return libraries[index] - -@promisify -def resolve_storage(self, info): - return [storage.pop()] - -@promisify -def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - return search_books + search_magazines - -@promisify -def resolve_hello(self, info): - return "Hello!" - -@promisify -def resolve_echo(self, info, echo): - return echo - -@promisify -def resolve_error(self, info): - raise RuntimeError("Runtime Error!") - -@promisify -def resolve_storage_add(self, info, string): - storage.append(string) - return StorageAdd(string=string) - - -class StorageAdd(GrapheneMutation): - class Arguments: - string = String(required=True) - - string = String() - mutate = resolve_storage_add - - -class Query(ObjectType): - library = Field(Library, index=Int(required=True), resolver=resolve_library) - hello = String(resolver=resolve_hello) - search = Field(List(Item), contains=String(required=True), resolver=resolve_search) - echo = Field(String, echo=String(required=True), resolver=resolve_echo) - storage = Field(Storage, resolver=resolve_storage) - error = String(resolver=resolve_error) - error_non_null = Field(NonNull(String), resolver=resolve_error) - error_middleware = String(resolver=resolve_hello) - - -class Mutation(ObjectType): - storage_add = StorageAdd.Field() - - -target_schema = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/test_application.py b/tests/framework_graphene/test_application.py index c4d1f15d6..838f3b515 100644 --- a/tests/framework_graphene/test_application.py +++ b/tests/framework_graphene/test_application.py @@ -15,9 +15,12 @@ import pytest from framework_graphql.test_application import * +from newrelic.common.package_version_utils import get_package_version +GRAPHENE_VERSION = get_package_version("graphene") -@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async", "sync-promise", "async-promise"]) + +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async"]) def target_application(request): from ._target_application import target_application @@ -26,17 +29,9 @@ def target_application(request): pytest.skip("Unsupported combination.") return - try: - import graphene - version = graphene.__version__ - except Exception: - import pkg_resources - version = pkg_resources.get_distribution("graphene").version - param = request.param.split("-") is_background = param[0] not in {"wsgi", "asgi"} schema_type = param[1] extra_spans = 4 if param[0] == "wsgi" else 0 - - assert version is not None - return "Graphene", version, target_application, is_background, schema_type, extra_spans + assert GRAPHENE_VERSION is not None + return "Graphene", GRAPHENE_VERSION, target_application, is_background, schema_type, extra_spans diff --git a/tests/framework_graphql/_target_application.py b/tests/framework_graphql/_target_application.py index 5ed7d9edd..91da5d767 100644 --- a/tests/framework_graphql/_target_application.py +++ b/tests/framework_graphql/_target_application.py @@ -12,19 +12,13 @@ # See the License for 
the specific language governing permissions and # limitations under the License. -from graphql import __version__ as version from graphql.language.source import Source -from newrelic.packages import six -from newrelic.hooks.framework_graphql import is_promise - +from ._target_schema_async import target_schema as target_schema_async from ._target_schema_sync import target_schema as target_schema_sync -is_graphql_2 = int(version.split(".")[0]) == 2 - - -def check_response(query, response): +def check_response(query, response): if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: assert not response.errors, response.errors assert response.data @@ -41,7 +35,7 @@ def _run_sync(query, middleware=None): response = graphql(schema, query, middleware=middleware) - check_response(query, response) + check_response(query, response) return response.data @@ -50,6 +44,7 @@ def _run_sync(query, middleware=None): def run_async(schema): import asyncio + from graphql import graphql def _run_async(query, middleware=None): @@ -57,46 +52,15 @@ def _run_async(query, middleware=None): loop = asyncio.get_event_loop() response = loop.run_until_complete(coro) - check_response(query, response) + check_response(query, response) return response.data return _run_async -def run_promise(schema, scheduler): - from graphql import graphql - from promise import set_default_scheduler - - def _run_promise(query, middleware=None): - set_default_scheduler(scheduler) - - promise = graphql(schema, query, middleware=middleware, return_promise=True) - response = promise.get() - - check_response(query, response) - - return response.data - - return _run_promise - - target_application = { "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "async-async": run_async(target_schema_async), } - -if is_graphql_2: - from ._target_schema_promise import target_schema as target_schema_promise - from promise.schedulers.immediate import ImmediateScheduler - - if six.PY3: - from promise.schedulers.asyncio import AsyncioScheduler as AsyncScheduler - else: - from promise.schedulers.thread import ThreadScheduler as AsyncScheduler - - target_application["sync-promise"] = run_promise(target_schema_promise, ImmediateScheduler()) - target_application["async-promise"] = run_promise(target_schema_promise, AsyncScheduler()) -elif six.PY3: - from ._target_schema_async import target_schema as target_schema_async - target_application["async-sync"] = run_async(target_schema_sync) - target_application["async-async"] = run_async(target_schema_async) diff --git a/tests/framework_graphql/_target_schema_async.py b/tests/framework_graphql/_target_schema_async.py index 1ea417c10..aad4eb271 100644 --- a/tests/framework_graphql/_target_schema_async.py +++ b/tests/framework_graphql/_target_schema_async.py @@ -24,10 +24,7 @@ GraphQLUnionType, ) -try: - from ._target_schema_sync import books, libraries, magazines -except ImportError: - from framework_graphql._target_schema_sync import books, libraries, magazines +from ._target_schema_sync import books, libraries, magazines storage = [] @@ -106,62 +103,33 @@ async def resolve_error(root, info): raise RuntimeError("Runtime Error!") -try: - hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) - library_field = GraphQLField( - Library, - resolver=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), 
resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolver=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolver=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolver=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolver=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) -except TypeError: - hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) - library_field = GraphQLField( - Library, - resolve=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolve=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolve=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolve=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolve=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) +hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) +library_field = GraphQLField( + Library, + resolve=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) +search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolve=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolve=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolve=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolve=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) query = GraphQLObjectType( name="Query", diff --git a/tests/framework_graphql/_target_schema_promise.py b/tests/framework_graphql/_target_schema_promise.py deleted file mode 100644 index b0bf8cef7..000000000 --- a/tests/framework_graphql/_target_schema_promise.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from graphql import ( - GraphQLArgument, - GraphQLField, - GraphQLInt, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLSchema, - GraphQLString, - GraphQLUnionType, -) -from promise import promisify - -from ._target_schema_sync import books, libraries, magazines - -storage = [] - - -@promisify -def resolve_library(parent, info, index): - return libraries[index] - - -@promisify -def resolve_storage_add(parent, info, string): - storage.append(string) - return string - - -@promisify -def resolve_storage(parent, info): - return [storage.pop()] - - -@promisify -def resolve_search(parent, info, contains): - search_books = [b for b in books if contains in b["name"]] - search_magazines = [m for m in magazines if contains in m["name"]] - return search_books + search_magazines - - -Author = GraphQLObjectType( - "Author", - { - "first_name": GraphQLField(GraphQLString), - "last_name": GraphQLField(GraphQLString), - }, -) - -Book = GraphQLObjectType( - "Book", - { - "id": GraphQLField(GraphQLInt), - "name": GraphQLField(GraphQLString), - "isbn": GraphQLField(GraphQLString), - "author": GraphQLField(Author), - "branch": GraphQLField(GraphQLString), - }, -) - -Magazine = GraphQLObjectType( - "Magazine", - { - "id": GraphQLField(GraphQLInt), - "name": GraphQLField(GraphQLString), - "issue": GraphQLField(GraphQLInt), - "branch": GraphQLField(GraphQLString), - }, -) - - -Library = GraphQLObjectType( - "Library", - { - "id": GraphQLField(GraphQLInt), - "branch": GraphQLField(GraphQLString), - "book": GraphQLField(GraphQLList(Book)), - "magazine": GraphQLField(GraphQLList(Magazine)), - }, -) - -Storage = GraphQLList(GraphQLString) - - -@promisify -def resolve_hello(root, info): - return "Hello!" 
- - -@promisify -def resolve_echo(root, info, echo): - return echo - - -@promisify -def resolve_error(root, info): - raise RuntimeError("Runtime Error!") - - -try: - hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) - library_field = GraphQLField( - Library, - resolver=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolver=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolver=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolver=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolver=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) -except TypeError: - hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) - library_field = GraphQLField( - Library, - resolve=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolve=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolve=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolve=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolve=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) - -query = GraphQLObjectType( - name="Query", - fields={ - "hello": hello_field, - "library": library_field, - "search": search_field, - "echo": echo_field, - "storage": storage_field, - "error": error_field, - "error_non_null": error_non_null_field, - "error_middleware": error_middleware_field, - }, -) - -mutation = GraphQLObjectType( - name="Mutation", - fields={ - "storage_add": storage_add_field, - }, -) - -target_schema = GraphQLSchema(query=query, mutation=mutation) diff --git a/tests/framework_graphql/_target_schema_sync.py b/tests/framework_graphql/_target_schema_sync.py index ddfd8d190..302a6c66e 100644 --- a/tests/framework_graphql/_target_schema_sync.py +++ b/tests/framework_graphql/_target_schema_sync.py @@ -158,62 +158,33 @@ def resolve_error(root, info): raise RuntimeError("Runtime Error!") -try: - hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) - library_field = GraphQLField( - Library, - resolver=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolver=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - 
storage_field = GraphQLField( - Storage, - resolver=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolver=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolver=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) -except TypeError: - hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) - library_field = GraphQLField( - Library, - resolve=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolve=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolve=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolve=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolve=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) +hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) +library_field = GraphQLField( + Library, + resolve=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) +search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolve=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolve=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolve=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolve=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) query = GraphQLObjectType( name="Query", diff --git a/tests/framework_graphql/conftest.py b/tests/framework_graphql/conftest.py index 48cac2226..5302da2b8 100644 --- a/tests/framework_graphql/conftest.py +++ b/tests/framework_graphql/conftest.py @@ -13,10 +13,12 @@ # limitations under the License. 
import pytest -import six - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) +from newrelic.packages import six _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -31,7 +33,8 @@ default_settings=_default_settings, ) -@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async", "sync-promise", "async-promise"]) + +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async"]) def target_application(request): from ._target_application import target_application diff --git a/tests/framework_graphql/test_application.py b/tests/framework_graphql/test_application.py index 2f636cf38..77abc5443 100644 --- a/tests/framework_graphql/test_application.py +++ b/tests/framework_graphql/test_application.py @@ -13,10 +13,11 @@ # limitations under the License. import pytest -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_code_level_metrics import ( - validate_code_level_metrics, +from framework_graphql.test_application_async import ( + error_middleware_async, + example_middleware_async, ) +from testing_support.fixtures import dt_enabled, override_application_settings from testing_support.validators.validate_code_level_metrics import ( validate_code_level_metrics, ) @@ -33,9 +34,11 @@ from newrelic.api.background_task import background_task from newrelic.common.object_names import callable_name -from newrelic.packages import six +from newrelic.common.package_version_utils import get_package_version +graphql_version = get_package_version("graphql-core") + def conditional_decorator(decorator, condition): def _conditional_decorator(func): if not condition: @@ -45,14 +48,6 @@ def _conditional_decorator(func): return _conditional_decorator -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 - - def to_graphql_source(query): def delay_import(): try: @@ -61,13 +56,6 @@ def delay_import(): # Fallback if Source is not implemented return query - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - return Source(query) return delay_import @@ -82,32 +70,27 @@ def error_middleware(next, root, info, **args): raise RuntimeError("Runtime Error!") -example_middleware = [example_middleware] -error_middleware = [error_middleware] - -if six.PY3: - try: - from test_application_async import error_middleware_async, example_middleware_async - except ImportError: - from framework_graphql.test_application_async import error_middleware_async, example_middleware_async +def test_no_harm_no_transaction(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application - example_middleware.append(example_middleware_async) - error_middleware.append(error_middleware_async) + def _test(): + response = target_application("{ __schema { types { name } } }") + _test() -def test_no_harm_no_transaction(target_application): - framework, version, target_application, is_bg, schema_type, extra_spans = target_application - response = target_application("{ __schema { types { name } 
} }") - assert not response.get("errors", None) +example_middleware = [example_middleware] +error_middleware = [error_middleware] +example_middleware.append(example_middleware_async) +error_middleware.append(error_middleware_async) _runtime_error_name = callable_name(RuntimeError) _test_runtime_error = [(_runtime_error_name, "Runtime Error!")] def _graphql_base_rollup_metrics(framework, version, background_task=True): - from graphql import __version__ as graphql_version + graphql_version = get_package_version("graphql-core") metrics = [ ("Python/Framework/GraphQL/%s" % graphql_version, 1), @@ -148,13 +131,12 @@ def test_basic(target_application): def _test(): response = target_application("{ hello }") assert response["hello"] == "Hello!" - assert not response.get("errors", None) _test() @dt_enabled -def test_query_and_mutation(target_application, is_graphql_2): +def test_query_and_mutation(target_application): framework, version, target_application, is_bg, schema_type, extra_spans = target_application mutation_path = "storage_add" if framework != "Graphene" else "storage_add.string" @@ -190,7 +172,9 @@ def test_query_and_mutation(target_application, is_graphql_2): "graphql.field.returnType": "[String%s]%s" % (type_annotation, type_annotation), } - @validate_code_level_metrics("framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_storage_add") + @validate_code_level_metrics( + "framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_storage_add" + ) @validate_span_events(exact_agents=_expected_mutation_operation_attributes) @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) @validate_transaction_metrics( @@ -376,15 +360,12 @@ def _test(): ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), ], ) -def test_exception_in_validation(target_application, is_graphql_2, query, exc_class): +def test_exception_in_validation(target_application, query, exc_class): framework, version, target_application, is_bg, schema_type, extra_spans = target_application if "syntax" in query: txn_name = "graphql.language.parser:parse" else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" + txn_name = "graphql.validation.validate:validate" # Import path differs between versions if exc_class == "GraphQLError": @@ -450,7 +431,6 @@ def test_operation_metrics_and_attrs(target_application): @conditional_decorator(background_task(), is_bg) def _test(): response = target_application("query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert not response.get("errors", None) _test() @@ -583,7 +563,7 @@ def _test(): @pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_ignored_introspection_transactions(target_application, capture_introspection_setting): +def test_introspection_transactions(target_application, capture_introspection_setting): framework, version, target_application, is_bg, schema_type, extra_spans = target_application txn_ct = 1 if capture_introspection_setting else 0 @@ -594,6 +574,5 @@ def test_ignored_introspection_transactions(target_application, capture_introspe @background_task() def _test(): response = target_application("{ __schema { types { name } } }") - assert not response.get("errors", None) - _test() + _test() \ No newline at end of file diff --git a/tests/framework_starlette/test_graphql.py b/tests/framework_starlette/test_graphql.py deleted file mode 100644 index 24ec3ab38..000000000 
--- a/tests/framework_starlette/test_graphql.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -def get_starlette_version(): - import starlette - - version = getattr(starlette, "__version__", "0.0.0").split(".") - return tuple(int(x) for x in version) - - -@pytest.fixture(scope="session") -def target_application(): - import _test_graphql - - return _test_graphql.target_application - - -@dt_enabled -@pytest.mark.parametrize("endpoint", ("/async", "/sync")) -@pytest.mark.skipif(get_starlette_version() >= (0, 17), reason="Starlette GraphQL support dropped in v0.17.0") -def test_graphql_metrics_and_attrs(target_application, endpoint): - from graphql import __version__ as version - - from newrelic.hooks.framework_graphene import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Graphene/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_scoped_metrics = [ - ("GraphQL/resolve/Graphene/hello", 1), - ("GraphQL/operation/Graphene/query//hello", 1), - ] - _test_unscoped_metrics = [ - ("GraphQL/all", 1), - ("GraphQL/Graphene/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Graphene/allWeb", 1), - ] + _test_scoped_metrics - - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - "graphql.operation.query": "{ hello }", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_scoped_metrics, - rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS, - ) - def _test(): - response = target_application.make_request( - "POST", endpoint, body=json.dumps({"query": "{ hello }"}), headers={"Content-Type": "application/json"} - ) - assert response.status == 200 - assert "Hello!" 
in response.body.decode("utf-8") - - _test() diff --git a/tests/framework_strawberry/_target_application.py b/tests/framework_strawberry/_target_application.py index ec618be5c..afba04873 100644 --- a/tests/framework_strawberry/_target_application.py +++ b/tests/framework_strawberry/_target_application.py @@ -15,10 +15,18 @@ import asyncio import json -import pytest -from ._target_schema_sync import target_schema as target_schema_sync, target_asgi_application as target_asgi_application_sync -from ._target_schema_async import target_schema as target_schema_async, target_asgi_application as target_asgi_application_async +import pytest +from framework_strawberry._target_schema_async import ( + target_asgi_application as target_asgi_application_async, +) +from framework_strawberry._target_schema_async import ( + target_schema as target_schema_async, +) +from framework_strawberry._target_schema_sync import ( + target_asgi_application as target_asgi_application_sync, +) +from framework_strawberry._target_schema_sync import target_schema as target_schema_sync def run_sync(schema): @@ -36,6 +44,7 @@ def _run_sync(query, middleware=None): assert response.errors return response.data + return _run_sync @@ -55,6 +64,7 @@ def _run_async(query, middleware=None): assert response.errors return response.data + return _run_async @@ -78,6 +88,7 @@ def _run_asgi(query, middleware=None): assert "errors" not in body or not body["errors"] return body["data"] + return _run_asgi diff --git a/tests/framework_strawberry/_target_schema_async.py b/tests/framework_strawberry/_target_schema_async.py index 397166d4d..373cef537 100644 --- a/tests/framework_strawberry/_target_schema_async.py +++ b/tests/framework_strawberry/_target_schema_async.py @@ -13,17 +13,23 @@ # limitations under the License. from typing import List + import strawberry.mutation import strawberry.type +from framework_strawberry._target_schema_sync import ( + Item, + Library, + Storage, + books, + libraries, + magazines, +) from strawberry import Schema, field from strawberry.asgi import GraphQL from strawberry.schema.config import StrawberryConfig from strawberry.types.types import Optional from testing_support.asgi_testing import AsgiTest -from ._target_schema_sync import Library, Item, Storage, books, magazines, libraries - - storage = [] diff --git a/tests/framework_strawberry/conftest.py b/tests/framework_strawberry/conftest.py index c5cdbf0c8..6345b3033 100644 --- a/tests/framework_strawberry/conftest.py +++ b/tests/framework_strawberry/conftest.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest -import six - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -30,7 +29,3 @@ app_name="Python Agent Test (framework_strawberry)", default_settings=_default_settings, ) - - -if six.PY2: - collect_ignore = ["test_application_async.py"] diff --git a/tests/framework_strawberry/test_application.py b/tests/framework_strawberry/test_application.py index 76082dee9..5a3f579ba 100644 --- a/tests/framework_strawberry/test_application.py +++ b/tests/framework_strawberry/test_application.py @@ -11,25 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import pytest from framework_graphql.test_application import * +from testing_support.fixtures import override_application_settings +from testing_support.validators.validate_transaction_count import ( + validate_transaction_count, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version + +STRAWBERRY_VERSION = get_package_version("strawberry-graphql") @pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async", "asgi-sync", "asgi-async"]) def target_application(request): from ._target_application import target_application - target_application = target_application[request.param] - try: - import strawberry - version = strawberry.__version__ - except Exception: - import pkg_resources - version = pkg_resources.get_distribution("strawberry-graphql").version + target_application = target_application[request.param] is_asgi = "asgi" in request.param schema_type = request.param.split("-")[1] - assert version is not None - return "Strawberry", version, target_application, not is_asgi, schema_type, 0 + assert STRAWBERRY_VERSION is not None + return "Strawberry", STRAWBERRY_VERSION, target_application, not is_asgi, schema_type, 0 + + +@pytest.mark.parametrize("capture_introspection_setting", (True, False)) +def test_introspection_transactions(target_application, capture_introspection_setting): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + + txn_ct = 1 if capture_introspection_setting else 0 + + @override_application_settings( + {"instrumentation.graphql.capture_introspection_queries": capture_introspection_setting} + ) + @validate_transaction_count(txn_ct) + @background_task() + def _test(): + response = target_application("{ __schema { types { name } } }") + _test() diff --git a/tests/messagebroker_pika/test_pika_async_connection_consume.py b/tests/messagebroker_pika/test_pika_async_connection_consume.py index 4e44c7ed7..5a5ce86b6 100644 --- a/tests/messagebroker_pika/test_pika_async_connection_consume.py +++ b/tests/messagebroker_pika/test_pika_async_connection_consume.py @@ -49,20 +49,20 @@ from newrelic.api.background_task import background_task + DB_SETTINGS = rabbitmq_settings()[0] _message_broker_tt_params = { - "queue_name": QUEUE, - "routing_key": QUEUE, - "correlation_id": CORRELATION_ID, - "reply_to": REPLY_TO, - "headers": HEADERS.copy(), + 'queue_name': QUEUE, + 'routing_key': QUEUE, + 'correlation_id': CORRELATION_ID, + 'reply_to': REPLY_TO, + 
'headers': HEADERS.copy(), } # Tornado's IO loop is not configurable in versions 5.x and up try: - class MyIOLoop(tornado.ioloop.IOLoop.configured_class()): def handle_callback_exception(self, *args, **kwargs): raise @@ -73,44 +73,38 @@ def handle_callback_exception(self, *args, **kwargs): connection_classes = [pika.SelectConnection, TornadoConnection] -parametrized_connection = pytest.mark.parametrize("ConnectionClass", connection_classes) +parametrized_connection = pytest.mark.parametrize('ConnectionClass', + connection_classes) _test_select_conn_basic_get_inside_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, 1), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, 1), ] if six.PY3: _test_select_conn_basic_get_inside_txn_metrics.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_get_inside_txn." - ".on_message" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_get_inside_txn.' + '.on_message'), 1)) else: - _test_select_conn_basic_get_inside_txn_metrics.append(("Function/test_pika_async_connection_consume:on_message", 1)) + _test_select_conn_basic_get_inside_txn_metrics.append( + ('Function/test_pika_async_connection_consume:on_message', 1)) @parametrized_connection -@pytest.mark.parametrize("callback_as_partial", [True, False]) -@validate_code_level_metrics( - "test_pika_async_connection_consume" + (".test_async_connection_basic_get_inside_txn." if six.PY3 else ""), - "on_message", -) +@pytest.mark.parametrize('callback_as_partial', [True, False]) +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_get_inside_txn.", "on_message", py2_namespace="test_pika_async_connection_consume") @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_get_inside_txn"), - scoped_metrics=_test_select_conn_basic_get_inside_txn_metrics, - rollup_metrics=_test_select_conn_basic_get_inside_txn_metrics, - background_task=True, -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_get_inside_txn'), + scoped_metrics=_test_select_conn_basic_get_inside_txn_metrics, + rollup_metrics=_test_select_conn_basic_get_inside_txn_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() -def test_async_connection_basic_get_inside_txn(producer, ConnectionClass, callback_as_partial): +def test_async_connection_basic_get_inside_txn(producer, ConnectionClass, + callback_as_partial): def on_message(channel, method_frame, header_frame, body): assert method_frame assert body == BODY @@ -128,7 +122,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -139,8 +135,9 @@ def on_open_connection(connection): @parametrized_connection -@pytest.mark.parametrize("callback_as_partial", [True, False]) -def test_select_connection_basic_get_outside_txn(producer, ConnectionClass, callback_as_partial): +@pytest.mark.parametrize('callback_as_partial', 
[True, False]) +def test_select_connection_basic_get_outside_txn(producer, ConnectionClass, + callback_as_partial): metrics_list = [] @capture_transaction_metrics(metrics_list) @@ -163,8 +160,8 @@ def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) connection = ConnectionClass( - pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection - ) + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -181,24 +178,25 @@ def on_open_connection(connection): _test_select_conn_basic_get_inside_txn_no_callback_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] @pytest.mark.skipif( - condition=pika_version_info[0] > 0, reason="pika 1.0 removed the ability to use basic_get with callback=None" -) + condition=pika_version_info[0] > 0, + reason='pika 1.0 removed the ability to use basic_get with callback=None') @parametrized_connection @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_get_inside_txn_no_callback"), + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_get_inside_txn_no_callback'), scoped_metrics=_test_select_conn_basic_get_inside_txn_no_callback_metrics, rollup_metrics=_test_select_conn_basic_get_inside_txn_no_callback_metrics, - background_task=True, -) + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() -def test_async_connection_basic_get_inside_txn_no_callback(producer, ConnectionClass): +def test_async_connection_basic_get_inside_txn_no_callback(producer, + ConnectionClass): def on_open_channel(channel): channel.basic_get(callback=None, queue=QUEUE) channel.close() @@ -208,7 +206,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -219,26 +219,27 @@ def on_open_connection(connection): _test_async_connection_basic_get_empty_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] @parametrized_connection -@pytest.mark.parametrize("callback_as_partial", [True, False]) +@pytest.mark.parametrize('callback_as_partial', [True, False]) @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_get_empty"), - scoped_metrics=_test_async_connection_basic_get_empty_metrics, - rollup_metrics=_test_async_connection_basic_get_empty_metrics, - background_task=True, -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_get_empty'), + scoped_metrics=_test_async_connection_basic_get_empty_metrics, + rollup_metrics=_test_async_connection_basic_get_empty_metrics, + background_task=True) 
@validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() -def test_async_connection_basic_get_empty(ConnectionClass, callback_as_partial): - QUEUE = "test_async_empty" +def test_async_connection_basic_get_empty(ConnectionClass, + callback_as_partial): + QUEUE = 'test_async_empty' def on_message(channel, method_frame, header_frame, body): - assert False, body.decode("UTF-8") + assert False, body.decode('UTF-8') if callback_as_partial: on_message = functools.partial(on_message) @@ -252,7 +253,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -263,42 +266,33 @@ def on_open_connection(connection): _test_select_conn_basic_consume_in_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] if six.PY3: _test_select_conn_basic_consume_in_txn_metrics.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_consume_inside_txn." - ".on_message" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_inside_txn.' + '.on_message'), 1)) else: - _test_select_conn_basic_consume_in_txn_metrics.append(("Function/test_pika_async_connection_consume:on_message", 1)) + _test_select_conn_basic_consume_in_txn_metrics.append( + ('Function/test_pika_async_connection_consume:on_message', 1)) @parametrized_connection @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_consume_inside_txn"), - scoped_metrics=_test_select_conn_basic_consume_in_txn_metrics, - rollup_metrics=_test_select_conn_basic_consume_in_txn_metrics, - background_task=True, -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_async_connection_basic_consume_inside_txn." 
if six.PY3 else ""), - "on_message", -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_inside_txn'), + scoped_metrics=_test_select_conn_basic_consume_in_txn_metrics, + rollup_metrics=_test_select_conn_basic_consume_in_txn_metrics, + background_task=True) +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_inside_txn.", "on_message", py2_namespace="test_pika_async_connection_consume") @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_async_connection_basic_consume_inside_txn(producer, ConnectionClass): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.basic_ack(method_frame.delivery_tag) channel.close() @@ -311,7 +305,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -322,67 +318,46 @@ def on_open_connection(connection): _test_select_conn_basic_consume_two_exchanges = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE_2, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE_2, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE_2, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE_2, None), ] if six.PY3: _test_select_conn_basic_consume_two_exchanges.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_consume_two_exchanges." - ".on_message_1" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_two_exchanges.' + '.on_message_1'), 1)) _test_select_conn_basic_consume_two_exchanges.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_consume_two_exchanges." - ".on_message_2" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_two_exchanges.' + '.on_message_2'), 1)) else: _test_select_conn_basic_consume_two_exchanges.append( - ("Function/test_pika_async_connection_consume:on_message_1", 1) - ) + ('Function/test_pika_async_connection_consume:on_message_1', 1)) _test_select_conn_basic_consume_two_exchanges.append( - ("Function/test_pika_async_connection_consume:on_message_2", 1) - ) + ('Function/test_pika_async_connection_consume:on_message_2', 1)) @parametrized_connection @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_consume_two_exchanges"), - scoped_metrics=_test_select_conn_basic_consume_two_exchanges, - rollup_metrics=_test_select_conn_basic_consume_two_exchanges, - background_task=True, -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_async_connection_basic_consume_two_exchanges." 
if six.PY3 else ""), - "on_message_1", -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_async_connection_basic_consume_two_exchanges." if six.PY3 else ""), - "on_message_2", -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_two_exchanges'), + scoped_metrics=_test_select_conn_basic_consume_two_exchanges, + rollup_metrics=_test_select_conn_basic_consume_two_exchanges, + background_task=True) +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_two_exchanges.", "on_message_1", py2_namespace="test_pika_async_connection_consume") +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_two_exchanges.", "on_message_2", py2_namespace="test_pika_async_connection_consume") @background_task() -def test_async_connection_basic_consume_two_exchanges(producer, producer_2, ConnectionClass): +def test_async_connection_basic_consume_two_exchanges(producer, producer_2, + ConnectionClass): global events_received events_received = 0 def on_message_1(channel, method_frame, header_frame, body): channel.basic_ack(method_frame.delivery_tag) - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY global events_received @@ -395,7 +370,7 @@ def on_message_1(channel, method_frame, header_frame, body): def on_message_2(channel, method_frame, header_frame, body): channel.basic_ack(method_frame.delivery_tag) - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY global events_received @@ -413,7 +388,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -424,11 +401,12 @@ def on_open_connection(connection): # This should not create a transaction -@function_not_called("newrelic.core.stats_engine", "StatsEngine.record_transaction") -@override_application_settings({"debug.record_transaction_failure": True}) +@function_not_called('newrelic.core.stats_engine', + 'StatsEngine.record_transaction') +@override_application_settings({'debug.record_transaction_failure': True}) def test_tornado_connection_basic_consume_outside_transaction(producer): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.basic_ack(method_frame.delivery_tag) channel.close() @@ -441,7 +419,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = TornadoConnection(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = TornadoConnection( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -452,44 +432,31 @@ def on_open_connection(connection): if six.PY3: - _txn_name = ( - "test_pika_async_connection_consume:" - "test_select_connection_basic_consume_outside_transaction." - ".on_message" - ) + _txn_name = ('test_pika_async_connection_consume:' + 'test_select_connection_basic_consume_outside_transaction.' 
+ '.on_message') _test_select_connection_consume_outside_txn_metrics = [ - ( - ( - "Function/test_pika_async_connection_consume:" - "test_select_connection_basic_consume_outside_transaction." - ".on_message" - ), - None, - ) - ] + (('Function/test_pika_async_connection_consume:' + 'test_select_connection_basic_consume_outside_transaction.' + '.on_message'), None)] else: - _txn_name = "test_pika_async_connection_consume:on_message" + _txn_name = ( + 'test_pika_async_connection_consume:on_message') _test_select_connection_consume_outside_txn_metrics = [ - ("Function/test_pika_async_connection_consume:on_message", None) - ] + ('Function/test_pika_async_connection_consume:on_message', None)] # This should create a transaction @validate_transaction_metrics( - _txn_name, - scoped_metrics=_test_select_connection_consume_outside_txn_metrics, - rollup_metrics=_test_select_connection_consume_outside_txn_metrics, - background_task=True, - group="Message/RabbitMQ/Exchange/%s" % EXCHANGE, -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_select_connection_basic_consume_outside_transaction." if six.PY3 else ""), - "on_message", -) + _txn_name, + scoped_metrics=_test_select_connection_consume_outside_txn_metrics, + rollup_metrics=_test_select_connection_consume_outside_txn_metrics, + background_task=True, + group='Message/RabbitMQ/Exchange/%s' % EXCHANGE) +@validate_code_level_metrics("test_pika_async_connection_consume.test_select_connection_basic_consume_outside_transaction.", "on_message", py2_namespace="test_pika_async_connection_consume") def test_select_connection_basic_consume_outside_transaction(producer): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.basic_ack(method_frame.delivery_tag) channel.close() @@ -503,12 +470,12 @@ def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) connection = pika.SelectConnection( - pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection - ) + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() except: connection.close() connection.ioloop.stop() - raise + raise \ No newline at end of file diff --git a/tests/messagebroker_pika/test_pika_blocking_connection_consume.py b/tests/messagebroker_pika/test_pika_blocking_connection_consume.py index 7b41674a2..92df917f7 100644 --- a/tests/messagebroker_pika/test_pika_blocking_connection_consume.py +++ b/tests/messagebroker_pika/test_pika_blocking_connection_consume.py @@ -38,30 +38,32 @@ DB_SETTINGS = rabbitmq_settings()[0] _message_broker_tt_params = { - "queue_name": QUEUE, - "routing_key": QUEUE, - "correlation_id": CORRELATION_ID, - "reply_to": REPLY_TO, - "headers": HEADERS.copy(), + 'queue_name': QUEUE, + 'routing_key': QUEUE, + 'correlation_id': CORRELATION_ID, + 'reply_to': REPLY_TO, + 'headers': HEADERS.copy(), } _test_blocking_connection_basic_get_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, 1), - (("Function/pika.adapters.blocking_connection:" "_CallbackResult.set_value_once"), 1), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, 1), + (('Function/pika.adapters.blocking_connection:' + '_CallbackResult.set_value_once'), 1) ] 
@validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_get"), - scoped_metrics=_test_blocking_connection_basic_get_metrics, - rollup_metrics=_test_blocking_connection_basic_get_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_get'), + scoped_metrics=_test_blocking_connection_basic_get_metrics, + rollup_metrics=_test_blocking_connection_basic_get_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_get(producer): - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() method_frame, _, _ = channel.basic_get(QUEUE) assert method_frame @@ -69,22 +71,23 @@ def test_blocking_connection_basic_get(producer): _test_blocking_connection_basic_get_empty_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_get_empty"), - scoped_metrics=_test_blocking_connection_basic_get_empty_metrics, - rollup_metrics=_test_blocking_connection_basic_get_empty_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_get_empty'), + scoped_metrics=_test_blocking_connection_basic_get_empty_metrics, + rollup_metrics=_test_blocking_connection_basic_get_empty_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_get_empty(): - QUEUE = "test_blocking_empty-%s" % os.getpid() - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + QUEUE = 'test_blocking_empty-%s' % os.getpid() + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() channel.queue_declare(queue=QUEUE) @@ -100,7 +103,8 @@ def test_blocking_connection_basic_get_outside_transaction(producer): @capture_transaction_metrics(metrics_list) def test_basic_get(): - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() channel.queue_declare(queue=QUEUE) @@ -116,57 +120,46 @@ def test_basic_get(): _test_blocking_conn_basic_consume_no_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] if six.PY3: - _txn_name = ( - "test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_outside_transaction." - ".on_message" - ) + _txn_name = ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_outside_transaction.' 
+ '.on_message') _test_blocking_conn_basic_consume_no_txn_metrics.append( - ( - ( - "Function/test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_outside_transaction." - ".on_message" - ), - None, - ) - ) + (('Function/test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_outside_transaction.' + '.on_message'), None)) else: - _txn_name = "test_pika_blocking_connection_consume:" "on_message" + _txn_name = ('test_pika_blocking_connection_consume:' + 'on_message') _test_blocking_conn_basic_consume_no_txn_metrics.append( - ("Function/test_pika_blocking_connection_consume:on_message", None) - ) + ('Function/test_pika_blocking_connection_consume:on_message', None)) -@pytest.mark.parametrize("as_partial", [True, False]) -@validate_code_level_metrics( - "test_pika_blocking_connection_consume" - + (".test_blocking_connection_basic_consume_outside_transaction." if six.PY3 else ""), - "on_message", -) +@pytest.mark.parametrize('as_partial', [True, False]) +@validate_code_level_metrics("test_pika_blocking_connection_consume.test_blocking_connection_basic_consume_outside_transaction.", "on_message", py2_namespace="test_pika_blocking_connection_consume") @validate_transaction_metrics( - _txn_name, - scoped_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, - rollup_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, - background_task=True, - group="Message/RabbitMQ/Exchange/%s" % EXCHANGE, -) + _txn_name, + scoped_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, + rollup_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, + background_task=True, + group='Message/RabbitMQ/Exchange/%s' % EXCHANGE) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) -def test_blocking_connection_basic_consume_outside_transaction(producer, as_partial): +def test_blocking_connection_basic_consume_outside_transaction(producer, + as_partial): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.stop_consuming() if as_partial: on_message = functools.partial(on_message) - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() basic_consume(channel, QUEUE, on_message) @@ -178,51 +171,41 @@ def on_message(channel, method_frame, header_frame, body): _test_blocking_conn_basic_consume_in_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] if six.PY3: _test_blocking_conn_basic_consume_in_txn_metrics.append( - ( - ( - "Function/test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_inside_txn." - ".on_message" - ), - 1, - ) - ) + (('Function/test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_inside_txn.' 
+ '.on_message'), 1)) else: _test_blocking_conn_basic_consume_in_txn_metrics.append( - ("Function/test_pika_blocking_connection_consume:on_message", 1) - ) + ('Function/test_pika_blocking_connection_consume:on_message', 1)) -@pytest.mark.parametrize("as_partial", [True, False]) -@validate_code_level_metrics( - "test_pika_blocking_connection_consume" - + (".test_blocking_connection_basic_consume_inside_txn." if six.PY3 else ""), - "on_message", -) +@pytest.mark.parametrize('as_partial', [True, False]) +@validate_code_level_metrics("test_pika_blocking_connection_consume.test_blocking_connection_basic_consume_inside_txn.", "on_message", py2_namespace="test_pika_blocking_connection_consume") @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_consume_inside_txn"), - scoped_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, - rollup_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_inside_txn'), + scoped_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, + rollup_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_consume_inside_txn(producer, as_partial): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.stop_consuming() if as_partial: on_message = functools.partial(on_message) - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() basic_consume(channel, QUEUE, on_message) try: @@ -233,40 +216,33 @@ def on_message(channel, method_frame, header_frame, body): _test_blocking_conn_basic_consume_stopped_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), - ("OtherTransaction/Message/RabbitMQ/Exchange/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), + ('OtherTransaction/Message/RabbitMQ/Exchange/Named/%s' % EXCHANGE, None), ] if six.PY3: _test_blocking_conn_basic_consume_stopped_txn_metrics.append( - ( - ( - "Function/test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_stopped_txn." - ".on_message" - ), - None, - ) - ) + (('Function/test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_stopped_txn.' 
+ '.on_message'), None)) else: _test_blocking_conn_basic_consume_stopped_txn_metrics.append( - ("Function/test_pika_blocking_connection_consume:on_message", None) - ) + ('Function/test_pika_blocking_connection_consume:on_message', None)) -@pytest.mark.parametrize("as_partial", [True, False]) +@pytest.mark.parametrize('as_partial', [True, False]) @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_consume_stopped_txn"), - scoped_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, - rollup_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_stopped_txn'), + scoped_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, + rollup_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_consume_stopped_txn(producer, as_partial): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.stop_consuming() @@ -275,11 +251,12 @@ def on_message(channel, method_frame, header_frame, body): if as_partial: on_message = functools.partial(on_message) - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() basic_consume(channel, QUEUE, on_message) try: channel.start_consuming() except: channel.stop_consuming() - raise + raise \ No newline at end of file diff --git a/tests/testing_support/db_settings.py b/tests/testing_support/db_settings.py index e32e2ecfa..f7bda3d7a 100644 --- a/tests/testing_support/db_settings.py +++ b/tests/testing_support/db_settings.py @@ -190,6 +190,28 @@ def mongodb_settings(): return settings +def firestore_settings(): + """Return a list of dict of settings for connecting to firestore. + + This only includes the host and port as the collection name is defined in + the firestore conftest file. + This will return the correct settings for whichever environment it is + running in. Settings are resolved in the following order, where later + environments override earlier ones. + + 1. Local + 2. GitHub Actions + """ + + host = "host.docker.internal" if "GITHUB_ACTIONS" in os.environ else "127.0.0.1" + instances = 2 + settings = [ + {"host": host, "port": 8080 + instance_num} + for instance_num in range(instances) + ] + return settings + + def elasticsearch_settings(): """Return a list of dict of settings for connecting to elasticsearch. 
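
The firestore_settings() helper above only advertises host/port pairs; wiring them into a client is left to the firestore conftest, which the docstring references but which is not shown in this patch. A minimal sketch of that wiring, assuming the google-cloud-firestore package and its FIRESTORE_EMULATOR_HOST convention; the fixture name is hypothetical:

import os

import pytest
from google.cloud.firestore import Client
from testing_support.db_settings import firestore_settings

DB_SETTINGS = firestore_settings()[0]


@pytest.fixture(scope="session")
def client():
    # With FIRESTORE_EMULATOR_HOST set, the client talks to the emulator
    # service started in the workflow above and needs no real credentials.
    os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (DB_SETTINGS["host"], DB_SETTINGS["port"])
    return Client()
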
diff --git a/tests/testing_support/validators/validate_code_level_metrics.py b/tests/testing_support/validators/validate_code_level_metrics.py index 1f99d9d52..c3a880b35 100644 --- a/tests/testing_support/validators/validate_code_level_metrics.py +++ b/tests/testing_support/validators/validate_code_level_metrics.py @@ -17,6 +17,7 @@ from testing_support.fixtures import dt_enabled from newrelic.common.object_wrapper import function_wrapper + def validate_code_level_metrics(namespace, function, py2_namespace=None, builtin=False, count=1, index=-1): """Verify that code level metrics are generated for a callable.""" @@ -42,5 +43,4 @@ def validate_code_level_metrics(namespace, function, py2_namespace=None, builtin def wrapper(wrapped, instance, args, kwargs): validator(dt_enabled(wrapped))(*args, **kwargs) - return wrapper - + return wrapper \ No newline at end of file diff --git a/tests/testing_support/validators/validate_datastore_trace_inputs.py b/tests/testing_support/validators/validate_datastore_trace_inputs.py index ade4ebea6..365a14ebd 100644 --- a/tests/testing_support/validators/validate_datastore_trace_inputs.py +++ b/tests/testing_support/validators/validate_datastore_trace_inputs.py @@ -23,7 +23,7 @@ """ -def validate_datastore_trace_inputs(operation=None, target=None): +def validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None): @transient_function_wrapper("newrelic.api.datastore_trace", "DatastoreTrace.__init__") @catch_background_exceptions def _validate_datastore_trace_inputs(wrapped, instance, args, kwargs): @@ -44,6 +44,18 @@ def _bind_params(product, target, operation, host=None, port_path_or_id=None, da assert captured_target == target, "%s didn't match expected %s" % (captured_target, target) if operation is not None: assert captured_operation == operation, "%s didn't match expected %s" % (captured_operation, operation) + if host is not None: + assert captured_host == host, "%s didn't match expected %s" % (captured_host, host) + if port_path_or_id is not None: + assert captured_port_path_or_id == port_path_or_id, "%s didn't match expected %s" % ( + captured_port_path_or_id, + port_path_or_id, + ) + if database_name is not None: + assert captured_database_name == database_name, "%s didn't match expected %s" % ( + captured_database_name, + database_name, + ) return wrapped(*args, **kwargs) diff --git a/tox.ini b/tox.ini index d36edbe73..0aea970da 100644 --- a/tox.ini +++ b/tox.ini @@ -83,6 +83,7 @@ envlist = memcached-datastore_memcache-{py27,py37,py38,py39,py310,py311,pypy27,pypy38}-memcached01, mysql-datastore_mysql-mysql080023-py27, mysql-datastore_mysql-mysqllatest-{py37,py38,py39,py310,py311}, + firestore-datastore_firestore-{py37,py38,py39,py310,py311}, postgres-datastore_postgresql-{py37,py38,py39}, postgres-datastore_psycopg2-{py27,py37,py38,py39,py310,py311}-psycopg2latest postgres-datastore_psycopg2cffi-{py27,pypy27,py37,py38,py39,py310,py311}-psycopg2cffilatest, @@ -132,12 +133,9 @@ envlist = ; temporarily disabling flaskmaster tests python-framework_flask-{py37,py38,py39,py310,py311,pypy38}-flask{latest}, python-framework_graphene-{py37,py38,py39,py310,py311}-graphenelatest, - python-framework_graphene-{py27,py37,py38,py39,pypy27,pypy38}-graphene{0200,0201}, - python-framework_graphene-{py310,py311}-graphene0201, - python-framework_graphql-{py27,py37,py38,py39,py310,py311,pypy27,pypy38}-graphql02, - python-framework_graphql-{py37,py38,py39,py310,py311,pypy38}-graphql03, + 
python-framework_graphql-{py37,py38,py39,py310,py311,pypy38}-graphqllatest, ; temporarily disabling graphqlmaster tests - python-framework_graphql-py37-graphql{0202,0203,0300,0301,0302}, + python-framework_graphql-py37-graphql{0300,0301,0302}, grpc-framework_grpc-py27-grpc0125, grpc-framework_grpc-{py37,py38,py39,py310,py311}-grpclatest, python-framework_pyramid-{pypy27,py27,py38}-Pyramid0104, @@ -235,6 +233,7 @@ deps = datastore_elasticsearch: requests datastore_elasticsearch-elasticsearch07: elasticsearch<8.0 datastore_elasticsearch-elasticsearch08: elasticsearch<9.0 + datastore_firestore: google-cloud-firestore datastore_memcache-memcached01: python-memcached<2 datastore_mysql-mysqllatest: mysql-connector-python datastore_mysql-mysql080023: mysql-connector-python<8.0.24 @@ -308,12 +307,7 @@ deps = framework_flask-flaskmaster: https://github.com/pallets/werkzeug/archive/main.zip framework_flask-flaskmaster: https://github.com/pallets/flask/archive/main.zip#egg=flask[async] framework_graphene-graphenelatest: graphene - framework_graphene-graphene0200: graphene<2.1 - framework_graphene-graphene0201: graphene<2.2 - framework_graphql-graphql02: graphql-core<3 - framework_graphql-graphql03: graphql-core<4 - framework_graphql-graphql0202: graphql-core<2.3 - framework_graphql-graphql0203: graphql-core<2.4 + framework_graphql-graphqllatest: graphql-core<4 framework_graphql-graphql0300: graphql-core<3.1 framework_graphql-graphql0301: graphql-core<3.2 framework_graphql-graphql0302: graphql-core<3.3 @@ -345,7 +339,6 @@ deps = framework_sanic-saniclatest: sanic framework_sanic-sanic{1812,190301,1906}: aiohttp framework_sanic-sanic{1812,190301,1906,1912,200904,210300,2109,2112,2203,2290}: websockets<11 - framework_starlette: graphene<3 framework_starlette-starlette0014: starlette<0.15 framework_starlette-starlette0015: starlette<0.16 framework_starlette-starlette0019: starlette<0.20 @@ -437,6 +430,7 @@ changedir = datastore_asyncpg: tests/datastore_asyncpg datastore_bmemcached: tests/datastore_bmemcached datastore_elasticsearch: tests/datastore_elasticsearch + datastore_firestore: tests/datastore_firestore datastore_memcache: tests/datastore_memcache datastore_mysql: tests/datastore_mysql datastore_postgresql: tests/datastore_postgresql @@ -494,10 +488,10 @@ usefixtures = [coverage:run] branch = True disable_warnings = couldnt-parse -source = newrelic +source = newrelic [coverage:paths] -source = +source = newrelic/ .tox/**/site-packages/newrelic/ /__w/**/site-packages/newrelic/ @@ -506,4 +500,4 @@ source = directory = ${TOX_ENV_DIR-.}/htmlcov [coverage:xml] -output = ${TOX_ENV_DIR-.}/coverage.xml +output = ${TOX_ENV_DIR-.}/coverage.xml \ No newline at end of file
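
One pattern recurs across the framework_graphql and framework_strawberry diffs above: probing versions via `import module; module.__version__` (with a pkg_resources fallback) is replaced by the agent's get_package_version helper, keyed on the distribution name. A minimal sketch of the pattern; the tuple-comparison helper is an illustrative assumption, not part of the patch:

from newrelic.common.package_version_utils import get_package_version

# Look versions up by distribution name rather than importing the module;
# this is the pattern the updated tests use.
GRAPHQL_VERSION = get_package_version("graphql-core")
STRAWBERRY_VERSION = get_package_version("strawberry-graphql")
assert GRAPHQL_VERSION is not None and STRAWBERRY_VERSION is not None


def version_tuple(version):
    # Illustrative helper (an assumption, not in the patch): numeric compare.
    return tuple(int(part) for part in version.split(".") if part.isdigit())


# graphql-core 2.x support is dropped by this patch, so 3.x can be assumed.
assert version_tuple(GRAPHQL_VERSION) >= (3, 0)
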
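
Likewise, the extended validate_datastore_trace_inputs validator now accepts host, port_path_or_id, and database_name, letting the new firestore tests pin connection attributes as well as operation and target. A hedged sketch of how a test might use it; the collection name, the run_query fixture, and the "(default)" database id are hypothetical:

from testing_support.db_settings import firestore_settings
from testing_support.validators.validate_datastore_trace_inputs import (
    validate_datastore_trace_inputs,
)

from newrelic.api.background_task import background_task

DB_SETTINGS = firestore_settings()[0]


@validate_datastore_trace_inputs(
    operation="query",              # assumed operation name
    target="example_collection",    # hypothetical collection
    host=DB_SETTINGS["host"],
    port_path_or_id=DB_SETTINGS["port"],
    database_name="(default)",      # assumed: Firestore's default database id
)
@background_task()
def test_query_inputs(run_query):
    run_query()  # hypothetical fixture executing a Firestore query
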